From 77d13ccdae06351ec3dfc5b5a2b09343c0e62ac2 Mon Sep 17 00:00:00 2001 From: Jacob Olson Date: Thu, 14 May 2026 11:16:28 -0600 Subject: [PATCH 1/2] Add MigrationContext.Hooks for in-process hook implementations gh-ost's only hook extension point is on-disk scripts globbed from --hooks-path. Library callers that embed Migrator must either ship scripts and their dependencies alongside their binary or maintain a parallel Go layer that bridges script side effects back into the host application. Introduce a Hooks interface in go/base with one method per lifecycle event, and an optional MigrationContext.Hooks field. NewMigrator reads the field once at construction and falls back to the existing HooksExecutor when unset, so CLI behavior is unchanged. A CompositeHooks helper in go/logic lets callers run the on-disk script executor and their own Go implementation side-by-side. HooksExecutor's previously package-private method names are renamed (onStartup -> OnStartup, etc.) so external types can satisfy the interface. The struct and constructor were already exported but the methods weren't, so no usable external API is displaced. --- doc/hooks.md | 32 +++++++++ go/base/context.go | 1 + go/base/hooks.go | 24 +++++++ go/logic/hooks.go | 157 +++++++++++++++++++++++++++++++++++++---- go/logic/hooks_test.go | 121 +++++++++++++++++++++++++++++++ go/logic/migrator.go | 44 ++++++------ go/logic/server.go | 6 +- 7 files changed, 348 insertions(+), 37 deletions(-) create mode 100644 go/base/hooks.go diff --git a/doc/hooks.md b/doc/hooks.md index 0fa9fb628..9b331d9be 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -89,3 +89,35 @@ The following variable are available on particular hooks: ### Examples See [sample hooks](https://github.com/github/gh-ost/tree/master/resources/hooks-sample), as `bash` implementation samples. + +### Embedded usage: registering Go callbacks + +When `gh-ost` is consumed as a library (importing `github.com/github/gh-ost/go/logic`), callers can register Go functions for any hook event instead of, or in addition to, the on-disk script contract. Implement the `base.Hooks` interface and assign it to `MigrationContext.Hooks` before calling `logic.NewMigrator`: + +```go +import ( + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/logic" +) + +type myHooks struct{} + +func (myHooks) OnSuccess(instantDDL bool) error { return nil } +func (myHooks) OnFailure() error { return nil } +// ... implement the remaining base.Hooks methods. + +ctx.Hooks = &myHooks{} +m := logic.NewMigrator(ctx, version) +err := m.Migrate() +``` + +To run shell hooks from `--hooks-path` and Go callbacks together, wrap both in `logic.CompositeHooks`. Each member is invoked in order; the first non-nil error short-circuits, matching the script executor's behavior: + +```go +ctx.Hooks = logic.CompositeHooks{ + logic.NewHooksExecutor(ctx), // existing scripts under HooksPath + &myHooks{}, // additional Go callbacks +} +``` + +`MigrationContext.Hooks` is opt-in. When it is nil, `NewMigrator` wires the default script executor and behavior is identical to the CLI. Hooks are read once at `NewMigrator` time, so reassigning the field afterwards has no effect on the running migration. diff --git a/go/base/context.go b/go/base/context.go index c8eec7799..617e5bb13 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -156,6 +156,7 @@ type MigrationContext struct { HooksHintOwner string HooksHintToken string HooksStatusIntervalSec int64 + Hooks Hooks PanicOnWarnings bool Checkpoint bool CheckpointIntervalSeconds int64 diff --git a/go/base/hooks.go b/go/base/hooks.go new file mode 100644 index 000000000..688eade4b --- /dev/null +++ b/go/base/hooks.go @@ -0,0 +1,24 @@ +/* + Copyright 2026 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package base + +// Hooks is the set of lifecycle callbacks gh-ost invokes during a migration. +type Hooks interface { + OnStartup() error + OnValidated() error + OnRowCountComplete() error + OnBeforeRowCopy() error + OnRowCopyComplete() error + OnBeginPostponed() error + OnBeforeCutOver() error + OnInteractiveCommand(command string) error + OnSuccess(instantDDL bool) error + OnFailure() error + OnBatchCopyRetry(errorMessage string) error + OnStatus(statusMessage string) error + OnStopReplication() error + OnStartReplication() error +} diff --git a/go/logic/hooks.go b/go/logic/hooks.go index dfb18567d..c2c088e6a 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -34,6 +34,135 @@ const ( onStartReplication = "gh-ost-on-start-replication" ) +// CompositeHooks invokes each member in order, returning the first non-nil error. +type CompositeHooks []base.Hooks + +func (c CompositeHooks) OnStartup() error { + for _, h := range c { + if err := h.OnStartup(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnValidated() error { + for _, h := range c { + if err := h.OnValidated(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnRowCountComplete() error { + for _, h := range c { + if err := h.OnRowCountComplete(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeforeRowCopy() error { + for _, h := range c { + if err := h.OnBeforeRowCopy(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnRowCopyComplete() error { + for _, h := range c { + if err := h.OnRowCopyComplete(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeginPostponed() error { + for _, h := range c { + if err := h.OnBeginPostponed(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeforeCutOver() error { + for _, h := range c { + if err := h.OnBeforeCutOver(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnInteractiveCommand(command string) error { + for _, h := range c { + if err := h.OnInteractiveCommand(command); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnSuccess(instantDDL bool) error { + for _, h := range c { + if err := h.OnSuccess(instantDDL); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnFailure() error { + for _, h := range c { + if err := h.OnFailure(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBatchCopyRetry(errorMessage string) error { + for _, h := range c { + if err := h.OnBatchCopyRetry(errorMessage); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStatus(statusMessage string) error { + for _, h := range c { + if err := h.OnStatus(statusMessage); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStopReplication() error { + for _, h := range c { + if err := h.OnStopReplication(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStartReplication() error { + for _, h := range c { + if err := h.OnStartReplication(); err != nil { + return err + } + } + return nil +} + type HooksExecutor struct { migrationContext *base.MigrationContext writer io.Writer @@ -111,61 +240,61 @@ func (he *HooksExecutor) executeHooks(baseName string, extraVariables ...string) return nil } -func (he *HooksExecutor) onStartup() error { +func (he *HooksExecutor) OnStartup() error { return he.executeHooks(onStartup) } -func (he *HooksExecutor) onValidated() error { +func (he *HooksExecutor) OnValidated() error { return he.executeHooks(onValidated) } -func (he *HooksExecutor) onRowCountComplete() error { +func (he *HooksExecutor) OnRowCountComplete() error { return he.executeHooks(onRowCountComplete) } -func (he *HooksExecutor) onBeforeRowCopy() error { +func (he *HooksExecutor) OnBeforeRowCopy() error { return he.executeHooks(onBeforeRowCopy) } -func (he *HooksExecutor) onBatchCopyRetry(errorMessage string) error { +func (he *HooksExecutor) OnBatchCopyRetry(errorMessage string) error { v := fmt.Sprintf("GH_OST_LAST_BATCH_COPY_ERROR=%s", errorMessage) return he.executeHooks(onBatchCopyRetry, v) } -func (he *HooksExecutor) onRowCopyComplete() error { +func (he *HooksExecutor) OnRowCopyComplete() error { return he.executeHooks(onRowCopyComplete) } -func (he *HooksExecutor) onBeginPostponed() error { +func (he *HooksExecutor) OnBeginPostponed() error { return he.executeHooks(onBeginPostponed) } -func (he *HooksExecutor) onBeforeCutOver() error { +func (he *HooksExecutor) OnBeforeCutOver() error { return he.executeHooks(onBeforeCutOver) } -func (he *HooksExecutor) onInteractiveCommand(command string) error { +func (he *HooksExecutor) OnInteractiveCommand(command string) error { v := fmt.Sprintf("GH_OST_COMMAND='%s'", command) return he.executeHooks(onInteractiveCommand, v) } -func (he *HooksExecutor) onSuccess(instantDDL bool) error { +func (he *HooksExecutor) OnSuccess(instantDDL bool) error { v := fmt.Sprintf("GH_OST_INSTANT_DDL=%t", instantDDL) return he.executeHooks(onSuccess, v) } -func (he *HooksExecutor) onFailure() error { +func (he *HooksExecutor) OnFailure() error { return he.executeHooks(onFailure) } -func (he *HooksExecutor) onStatus(statusMessage string) error { +func (he *HooksExecutor) OnStatus(statusMessage string) error { v := fmt.Sprintf("GH_OST_STATUS='%s'", statusMessage) return he.executeHooks(onStatus, v) } -func (he *HooksExecutor) onStopReplication() error { +func (he *HooksExecutor) OnStopReplication() error { return he.executeHooks(onStopReplication) } -func (he *HooksExecutor) onStartReplication() error { +func (he *HooksExecutor) OnStartReplication() error { return he.executeHooks(onStartReplication) } diff --git a/go/logic/hooks_test.go b/go/logic/hooks_test.go index 94ea62b07..2ce82ab72 100644 --- a/go/logic/hooks_test.go +++ b/go/logic/hooks_test.go @@ -8,6 +8,7 @@ package logic import ( "bufio" "bytes" + "errors" "fmt" "os" "path/filepath" @@ -21,6 +22,126 @@ import ( "github.com/github/gh-ost/go/base" ) +type recordingHooks struct { + name string + calls *[]string + errOn string + errVal error +} + +func (r *recordingHooks) record(method string) error { + *r.calls = append(*r.calls, r.name+":"+method) + if r.errOn == method { + return r.errVal + } + return nil +} + +func (r *recordingHooks) OnStartup() error { return r.record("OnStartup") } +func (r *recordingHooks) OnValidated() error { return r.record("OnValidated") } +func (r *recordingHooks) OnRowCountComplete() error { return r.record("OnRowCountComplete") } +func (r *recordingHooks) OnBeforeRowCopy() error { return r.record("OnBeforeRowCopy") } +func (r *recordingHooks) OnRowCopyComplete() error { return r.record("OnRowCopyComplete") } +func (r *recordingHooks) OnBeginPostponed() error { return r.record("OnBeginPostponed") } +func (r *recordingHooks) OnBeforeCutOver() error { return r.record("OnBeforeCutOver") } +func (r *recordingHooks) OnInteractiveCommand(string) error { + return r.record("OnInteractiveCommand") +} +func (r *recordingHooks) OnSuccess(bool) error { return r.record("OnSuccess") } +func (r *recordingHooks) OnFailure() error { return r.record("OnFailure") } +func (r *recordingHooks) OnBatchCopyRetry(string) error { return r.record("OnBatchCopyRetry") } +func (r *recordingHooks) OnStatus(string) error { return r.record("OnStatus") } +func (r *recordingHooks) OnStopReplication() error { return r.record("OnStopReplication") } +func (r *recordingHooks) OnStartReplication() error { return r.record("OnStartReplication") } + +func TestCompositeHooks_FanOut(t *testing.T) { + var calls []string + composite := CompositeHooks{ + &recordingHooks{name: "a", calls: &calls}, + &recordingHooks{name: "b", calls: &calls}, + &recordingHooks{name: "c", calls: &calls}, + } + + require.NoError(t, composite.OnStartup()) + require.NoError(t, composite.OnBeforeCutOver()) + require.NoError(t, composite.OnSuccess(true)) + + require.Equal(t, []string{ + "a:OnStartup", "b:OnStartup", "c:OnStartup", + "a:OnBeforeCutOver", "b:OnBeforeCutOver", "c:OnBeforeCutOver", + "a:OnSuccess", "b:OnSuccess", "c:OnSuccess", + }, calls) +} + +func TestCompositeHooks_FirstErrorWins(t *testing.T) { + var calls []string + boom := errors.New("boom") + composite := CompositeHooks{ + &recordingHooks{name: "a", calls: &calls}, + &recordingHooks{name: "b", calls: &calls, errOn: "OnBeforeCutOver", errVal: boom}, + &recordingHooks{name: "c", calls: &calls}, // must not be called + } + + err := composite.OnBeforeCutOver() + require.ErrorIs(t, err, boom) + require.Equal(t, []string{"a:OnBeforeCutOver", "b:OnBeforeCutOver"}, calls) +} + +var ( + _ base.Hooks = (*HooksExecutor)(nil) + _ base.Hooks = (CompositeHooks)(nil) +) + +func TestCompositeHooks_WithShellExecutor(t *testing.T) { + ctx := base.NewMigrationContext() + ctx.DatabaseName = "test" + ctx.OriginalTableName = "tablename" + + hooksDir, err := os.MkdirTemp("", "TestCompositeHooks_WithShellExecutor") + require.NoError(t, err) + defer os.RemoveAll(hooksDir) + ctx.HooksPath = hooksDir + + sideEffect := filepath.Join(hooksDir, "ran.marker") + script := fmt.Sprintf("#!/bin/sh\ntouch %q\n", sideEffect) + require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "gh-ost-on-startup"), []byte(script), 0o777)) + + shellExec := NewHooksExecutor(ctx) + shellExec.writer = new(bytes.Buffer) // suppress stderr noise + + var calls []string + goFake := &recordingHooks{name: "go", calls: &calls} + + composite := CompositeHooks{shellExec, goFake} + require.NoError(t, composite.OnStartup()) + + _, err = os.Stat(sideEffect) + require.NoError(t, err, "shell hook should have created marker file") + require.Equal(t, []string{"go:OnStartup"}, calls, "Go fake should have been invoked once") +} + +func TestNewMigrator_HooksFromContext(t *testing.T) { + t.Run("default-is-script-executor", func(t *testing.T) { + ctx := base.NewMigrationContext() + m := NewMigrator(ctx, "test") + _, ok := m.hooksExecutor.(*HooksExecutor) + require.True(t, ok) + }) + + t.Run("context-hooks-take-precedence", func(t *testing.T) { + var calls []string + fake := &recordingHooks{name: "fake", calls: &calls} + + ctx := base.NewMigrationContext() + ctx.Hooks = fake + m := NewMigrator(ctx, "test") + + require.Same(t, fake, m.hooksExecutor) + require.NoError(t, m.hooksExecutor.OnStartup()) + require.Equal(t, []string{"fake:OnStartup"}, calls) + }) +} + func TestHooksExecutorExecuteHooks(t *testing.T) { migrationContext := base.NewMigrationContext() migrationContext.AlterStatement = "ENGINE=InnoDB" diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 9dc6041f5..aaacbdd2e 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -84,7 +84,7 @@ type Migrator struct { eventsStreamer *EventsStreamer server *Server throttler *Throttler - hooksExecutor *HooksExecutor + hooksExecutor base.Hooks migrationContext *base.MigrationContext firstThrottlingCollected chan bool @@ -103,9 +103,13 @@ type Migrator struct { } func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { + hooks := context.Hooks + if hooks == nil { + hooks = NewHooksExecutor(context) + } migrator := &Migrator{ appVersion: appVersion, - hooksExecutor: NewHooksExecutor(context), + hooksExecutor: hooks, migrationContext: context, parser: sql.NewAlterTableParser(), ghostTableMigrated: make(chan bool), @@ -145,7 +149,7 @@ func (mgtr *Migrator) sleepWhileTrue(operation func() (bool, error)) error { func (mgtr *Migrator) retryBatchCopyWithHooks(operation func() error, notFatalHint ...bool) (err error) { wrappedOperation := func() error { if err := operation(); err != nil { - mgtr.hooksExecutor.onBatchCopyRetry(err.Error()) + mgtr.hooksExecutor.OnBatchCopyRetry(err.Error()) return err } return nil @@ -388,7 +392,7 @@ func (mgtr *Migrator) countTableRows() (err error) { if err := mgtr.inspector.CountTableRows(ctx); err != nil { return err } - if err := mgtr.hooksExecutor.onRowCountComplete(); err != nil { + if err := mgtr.hooksExecutor.OnRowCountComplete(); err != nil { return err } return nil @@ -456,7 +460,7 @@ func (mgtr *Migrator) Migrate() (err error) { go mgtr.listenOnPanicAbort() - if err := mgtr.hooksExecutor.onStartup(); err != nil { + if err := mgtr.hooksExecutor.OnStartup(); err != nil { return err } if err := mgtr.parser.ParseAlterStatement(mgtr.migrationContext.AlterStatement); err != nil { @@ -508,7 +512,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(true); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(true); err != nil { return err } mgtr.migrationContext.Log.Infof("Success! table %s.%s migrated instantly", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -565,7 +569,7 @@ func (mgtr *Migrator) Migrate() (err error) { } // Validation complete! We're good to execute this migration - if err := mgtr.hooksExecutor.onValidated(); err != nil { + if err := mgtr.hooksExecutor.OnValidated(); err != nil { return err } @@ -586,7 +590,7 @@ func (mgtr *Migrator) Migrate() (err error) { mgtr.initiateThrottler() - if err := mgtr.hooksExecutor.onBeforeRowCopy(); err != nil { + if err := mgtr.hooksExecutor.OnBeforeRowCopy(); err != nil { return err } go func() { @@ -609,7 +613,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.checkAbort(); err != nil { return err } - if err := mgtr.hooksExecutor.onRowCopyComplete(); err != nil { + if err := mgtr.hooksExecutor.OnRowCopyComplete(); err != nil { return err } mgtr.printStatus(ForcePrintStatusRule) @@ -618,7 +622,7 @@ func (mgtr *Migrator) Migrate() (err error) { mgtr.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over") mgtr.migrationContext.CancelTableRowsCount() } - if err := mgtr.hooksExecutor.onBeforeCutOver(); err != nil { + if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { return err } var retrier func(func() error, ...bool) error @@ -644,7 +648,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(false); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { return err } mgtr.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -674,7 +678,7 @@ func (mgtr *Migrator) Revert() error { go mgtr.listenOnPanicAbort() - if err := mgtr.hooksExecutor.onStartup(); err != nil { + if err := mgtr.hooksExecutor.OnStartup(); err != nil { return err } if err := mgtr.validateAlterStatement(); err != nil { @@ -721,7 +725,7 @@ func (mgtr *Migrator) Revert() error { if err := mgtr.checkAbort(); err != nil { return err } - if err := mgtr.hooksExecutor.onValidated(); err != nil { + if err := mgtr.hooksExecutor.OnValidated(); err != nil { return err } if err := mgtr.initiateServer(); err != nil { @@ -748,7 +752,7 @@ func (mgtr *Migrator) Revert() error { } else { retrier = mgtr.retryOperation } - if err := mgtr.hooksExecutor.onBeforeCutOver(); err != nil { + if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { return err } if err := retrier(mgtr.cutOver); err != nil { @@ -758,7 +762,7 @@ func (mgtr *Migrator) Revert() error { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(false); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { return err } mgtr.migrationContext.Log.Infof("Done reverting %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -768,7 +772,7 @@ func (mgtr *Migrator) Revert() error { // ExecOnFailureHook executes the onFailure hook, and this method is provided as the only external // hook access point func (mgtr *Migrator) ExecOnFailureHook() (err error) { - return mgtr.hooksExecutor.onFailure() + return mgtr.hooksExecutor.OnFailure() } func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { @@ -786,7 +790,7 @@ func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. - if err := mgtr.hooksExecutor.onStartReplication(); err != nil { + if err := mgtr.hooksExecutor.OnStartReplication(); err != nil { return mgtr.migrationContext.Log.Errore(err) } if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { @@ -834,7 +838,7 @@ func (mgtr *Migrator) cutOver() (err error) { if base.FileExists(mgtr.migrationContext.PostponeCutOverFlagFile) { // Postpone file defined and exists! if atomic.LoadInt64(&mgtr.migrationContext.IsPostponingCutOver) == 0 { - if err := mgtr.hooksExecutor.onBeginPostponed(); err != nil { + if err := mgtr.hooksExecutor.OnBeginPostponed(); err != nil { return true, err } } @@ -855,7 +859,7 @@ func (mgtr *Migrator) cutOver() (err error) { // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. - if err := mgtr.hooksExecutor.onStopReplication(); err != nil { + if err := mgtr.hooksExecutor.OnStopReplication(); err != nil { return err } if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { @@ -1410,7 +1414,7 @@ func (mgtr *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { hooksStatusIntervalSec := mgtr.migrationContext.HooksStatusIntervalSec if hooksStatusIntervalSec > 0 && elapsedSeconds%hooksStatusIntervalSec == 0 { - mgtr.hooksExecutor.onStatus(status) + mgtr.hooksExecutor.OnStatus(status) } } diff --git a/go/logic/server.go b/go/logic/server.go index 7fe171f6e..4705ba9b9 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -38,12 +38,12 @@ type Server struct { migrationContext *base.MigrationContext unixListener net.Listener tcpListener net.Listener - hooksExecutor *HooksExecutor + hooksExecutor base.Hooks printStatus printStatusFunc isCPUProfiling int64 } -func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor base.Hooks, printStatus printStatusFunc) *Server { return &Server{ migrationContext: migrationContext, hooksExecutor: hooksExecutor, @@ -206,7 +206,7 @@ func (srv *Server) applyServerCommand(command string, writer *bufio.Writer) (pri argIsQuestion := (arg == "?") throttleHint := "# Note: you may only throttle for as long as your binary logs are not purged" - if err := srv.hooksExecutor.onInteractiveCommand(command); err != nil { + if err := srv.hooksExecutor.OnInteractiveCommand(command); err != nil { return NoPrintStatusRule, err } From 0927d5f6636e18db54740b7f7bc4ef18da780a3e Mon Sep 17 00:00:00 2001 From: Jacob Olson Date: Thu, 14 May 2026 11:37:57 -0600 Subject: [PATCH 2/2] Skip nil entries in CompositeHooks and self-contain doc example Address PR review feedback: - CompositeHooks.OnX methods skip nil members instead of panicking, allowing callers to conditionally append optional hooks. - doc/hooks.md embedded-usage snippet now defines ctx and version so it is self-contained. --- doc/hooks.md | 5 +++++ go/logic/hooks.go | 42 ++++++++++++++++++++++++++++++++++++++++++ go/logic/hooks_test.go | 14 ++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/doc/hooks.md b/doc/hooks.md index 9b331d9be..03ce26db0 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -100,12 +100,17 @@ import ( "github.com/github/gh-ost/go/logic" ) +const version = "1.1.8" + type myHooks struct{} func (myHooks) OnSuccess(instantDDL bool) error { return nil } func (myHooks) OnFailure() error { return nil } // ... implement the remaining base.Hooks methods. +ctx := base.NewMigrationContext() +// ... configure ctx (DatabaseName, OriginalTableName, AlterStatement, etc.) + ctx.Hooks = &myHooks{} m := logic.NewMigrator(ctx, version) err := m.Migrate() diff --git a/go/logic/hooks.go b/go/logic/hooks.go index c2c088e6a..1b36ede63 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -39,6 +39,9 @@ type CompositeHooks []base.Hooks func (c CompositeHooks) OnStartup() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnStartup(); err != nil { return err } @@ -48,6 +51,9 @@ func (c CompositeHooks) OnStartup() error { func (c CompositeHooks) OnValidated() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnValidated(); err != nil { return err } @@ -57,6 +63,9 @@ func (c CompositeHooks) OnValidated() error { func (c CompositeHooks) OnRowCountComplete() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnRowCountComplete(); err != nil { return err } @@ -66,6 +75,9 @@ func (c CompositeHooks) OnRowCountComplete() error { func (c CompositeHooks) OnBeforeRowCopy() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnBeforeRowCopy(); err != nil { return err } @@ -75,6 +87,9 @@ func (c CompositeHooks) OnBeforeRowCopy() error { func (c CompositeHooks) OnRowCopyComplete() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnRowCopyComplete(); err != nil { return err } @@ -84,6 +99,9 @@ func (c CompositeHooks) OnRowCopyComplete() error { func (c CompositeHooks) OnBeginPostponed() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnBeginPostponed(); err != nil { return err } @@ -93,6 +111,9 @@ func (c CompositeHooks) OnBeginPostponed() error { func (c CompositeHooks) OnBeforeCutOver() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnBeforeCutOver(); err != nil { return err } @@ -102,6 +123,9 @@ func (c CompositeHooks) OnBeforeCutOver() error { func (c CompositeHooks) OnInteractiveCommand(command string) error { for _, h := range c { + if h == nil { + continue + } if err := h.OnInteractiveCommand(command); err != nil { return err } @@ -111,6 +135,9 @@ func (c CompositeHooks) OnInteractiveCommand(command string) error { func (c CompositeHooks) OnSuccess(instantDDL bool) error { for _, h := range c { + if h == nil { + continue + } if err := h.OnSuccess(instantDDL); err != nil { return err } @@ -120,6 +147,9 @@ func (c CompositeHooks) OnSuccess(instantDDL bool) error { func (c CompositeHooks) OnFailure() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnFailure(); err != nil { return err } @@ -129,6 +159,9 @@ func (c CompositeHooks) OnFailure() error { func (c CompositeHooks) OnBatchCopyRetry(errorMessage string) error { for _, h := range c { + if h == nil { + continue + } if err := h.OnBatchCopyRetry(errorMessage); err != nil { return err } @@ -138,6 +171,9 @@ func (c CompositeHooks) OnBatchCopyRetry(errorMessage string) error { func (c CompositeHooks) OnStatus(statusMessage string) error { for _, h := range c { + if h == nil { + continue + } if err := h.OnStatus(statusMessage); err != nil { return err } @@ -147,6 +183,9 @@ func (c CompositeHooks) OnStatus(statusMessage string) error { func (c CompositeHooks) OnStopReplication() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnStopReplication(); err != nil { return err } @@ -156,6 +195,9 @@ func (c CompositeHooks) OnStopReplication() error { func (c CompositeHooks) OnStartReplication() error { for _, h := range c { + if h == nil { + continue + } if err := h.OnStartReplication(); err != nil { return err } diff --git a/go/logic/hooks_test.go b/go/logic/hooks_test.go index 2ce82ab72..0729df31f 100644 --- a/go/logic/hooks_test.go +++ b/go/logic/hooks_test.go @@ -73,6 +73,20 @@ func TestCompositeHooks_FanOut(t *testing.T) { }, calls) } +func TestCompositeHooks_SkipsNil(t *testing.T) { + var calls []string + composite := CompositeHooks{ + nil, + &recordingHooks{name: "a", calls: &calls}, + nil, + &recordingHooks{name: "b", calls: &calls}, + } + + require.NoError(t, composite.OnStartup()) + require.NoError(t, composite.OnSuccess(false)) + require.Equal(t, []string{"a:OnStartup", "b:OnStartup", "a:OnSuccess", "b:OnSuccess"}, calls) +} + func TestCompositeHooks_FirstErrorWins(t *testing.T) { var calls []string boom := errors.New("boom")