diff --git a/doc/hooks.md b/doc/hooks.md index 0fa9fb628..03ce26db0 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -89,3 +89,40 @@ The following variable are available on particular hooks: ### Examples See [sample hooks](https://github.com/github/gh-ost/tree/master/resources/hooks-sample), as `bash` implementation samples. + +### Embedded usage: registering Go callbacks + +When `gh-ost` is consumed as a library (importing `github.com/github/gh-ost/go/logic`), callers can register Go functions for any hook event instead of, or in addition to, the on-disk script contract. Implement the `base.Hooks` interface and assign it to `MigrationContext.Hooks` before calling `logic.NewMigrator`: + +```go +import ( + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/logic" +) + +const version = "1.1.8" + +type myHooks struct{} + +func (myHooks) OnSuccess(instantDDL bool) error { return nil } +func (myHooks) OnFailure() error { return nil } +// ... implement the remaining base.Hooks methods. + +ctx := base.NewMigrationContext() +// ... configure ctx (DatabaseName, OriginalTableName, AlterStatement, etc.) + +ctx.Hooks = &myHooks{} +m := logic.NewMigrator(ctx, version) +err := m.Migrate() +``` + +To run shell hooks from `--hooks-path` and Go callbacks together, wrap both in `logic.CompositeHooks`. Each member is invoked in order; the first non-nil error short-circuits, matching the script executor's behavior: + +```go +ctx.Hooks = logic.CompositeHooks{ + logic.NewHooksExecutor(ctx), // existing scripts under HooksPath + &myHooks{}, // additional Go callbacks +} +``` + +`MigrationContext.Hooks` is opt-in. When it is nil, `NewMigrator` wires the default script executor and behavior is identical to the CLI. When it is non-nil, it replaces the default script executor, so scripts under `--hooks-path` run only if a `HooksExecutor` is included via `logic.CompositeHooks` as shown above. Hooks are read once at `NewMigrator` time, so reassigning the field afterwards has no effect on the running migration. Errors returned from `OnStatus` and `OnBatchCopyRetry` are ignored by the migrator and do not abort the migration.
diff --git a/go/base/context.go b/go/base/context.go index c8eec7799..617e5bb13 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -156,6 +156,7 @@ type MigrationContext struct { HooksHintOwner string HooksHintToken string HooksStatusIntervalSec int64 + Hooks Hooks PanicOnWarnings bool Checkpoint bool CheckpointIntervalSeconds int64 diff --git a/go/base/hooks.go b/go/base/hooks.go new file mode 100644 index 000000000..688eade4b --- /dev/null +++ b/go/base/hooks.go @@ -0,0 +1,24 @@ +/* + Copyright 2026 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package base + +// Hooks is the set of lifecycle callbacks gh-ost invokes during a migration. +type Hooks interface { + OnStartup() error + OnValidated() error + OnRowCountComplete() error + OnBeforeRowCopy() error + OnRowCopyComplete() error + OnBeginPostponed() error + OnBeforeCutOver() error + OnInteractiveCommand(command string) error + OnSuccess(instantDDL bool) error + OnFailure() error + OnBatchCopyRetry(errorMessage string) error + OnStatus(statusMessage string) error + OnStopReplication() error + OnStartReplication() error +} diff --git a/go/logic/hooks.go b/go/logic/hooks.go index dfb18567d..1b36ede63 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -34,6 +34,177 @@ const ( onStartReplication = "gh-ost-on-start-replication" ) +// CompositeHooks invokes each member in order, returning the first non-nil error. 
+type CompositeHooks []base.Hooks + +func (c CompositeHooks) OnStartup() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnStartup(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnValidated() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnValidated(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnRowCountComplete() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnRowCountComplete(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeforeRowCopy() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnBeforeRowCopy(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnRowCopyComplete() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnRowCopyComplete(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeginPostponed() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnBeginPostponed(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnBeforeCutOver() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnBeforeCutOver(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnInteractiveCommand(command string) error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnInteractiveCommand(command); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnSuccess(instantDDL bool) error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnSuccess(instantDDL); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnFailure() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnFailure(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) 
OnBatchCopyRetry(errorMessage string) error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnBatchCopyRetry(errorMessage); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStatus(statusMessage string) error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnStatus(statusMessage); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStopReplication() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnStopReplication(); err != nil { + return err + } + } + return nil +} + +func (c CompositeHooks) OnStartReplication() error { + for _, h := range c { + if h == nil { + continue + } + if err := h.OnStartReplication(); err != nil { + return err + } + } + return nil +} + type HooksExecutor struct { migrationContext *base.MigrationContext writer io.Writer @@ -111,61 +282,61 @@ func (he *HooksExecutor) executeHooks(baseName string, extraVariables ...string) return nil } -func (he *HooksExecutor) onStartup() error { +func (he *HooksExecutor) OnStartup() error { return he.executeHooks(onStartup) } -func (he *HooksExecutor) onValidated() error { +func (he *HooksExecutor) OnValidated() error { return he.executeHooks(onValidated) } -func (he *HooksExecutor) onRowCountComplete() error { +func (he *HooksExecutor) OnRowCountComplete() error { return he.executeHooks(onRowCountComplete) } -func (he *HooksExecutor) onBeforeRowCopy() error { +func (he *HooksExecutor) OnBeforeRowCopy() error { return he.executeHooks(onBeforeRowCopy) } -func (he *HooksExecutor) onBatchCopyRetry(errorMessage string) error { +func (he *HooksExecutor) OnBatchCopyRetry(errorMessage string) error { v := fmt.Sprintf("GH_OST_LAST_BATCH_COPY_ERROR=%s", errorMessage) return he.executeHooks(onBatchCopyRetry, v) } -func (he *HooksExecutor) onRowCopyComplete() error { +func (he *HooksExecutor) OnRowCopyComplete() error { return he.executeHooks(onRowCopyComplete) } -func (he *HooksExecutor) 
onBeginPostponed() error { +func (he *HooksExecutor) OnBeginPostponed() error { return he.executeHooks(onBeginPostponed) } -func (he *HooksExecutor) onBeforeCutOver() error { +func (he *HooksExecutor) OnBeforeCutOver() error { return he.executeHooks(onBeforeCutOver) } -func (he *HooksExecutor) onInteractiveCommand(command string) error { +func (he *HooksExecutor) OnInteractiveCommand(command string) error { v := fmt.Sprintf("GH_OST_COMMAND='%s'", command) return he.executeHooks(onInteractiveCommand, v) } -func (he *HooksExecutor) onSuccess(instantDDL bool) error { +func (he *HooksExecutor) OnSuccess(instantDDL bool) error { v := fmt.Sprintf("GH_OST_INSTANT_DDL=%t", instantDDL) return he.executeHooks(onSuccess, v) } -func (he *HooksExecutor) onFailure() error { +func (he *HooksExecutor) OnFailure() error { return he.executeHooks(onFailure) } -func (he *HooksExecutor) onStatus(statusMessage string) error { +func (he *HooksExecutor) OnStatus(statusMessage string) error { v := fmt.Sprintf("GH_OST_STATUS='%s'", statusMessage) return he.executeHooks(onStatus, v) } -func (he *HooksExecutor) onStopReplication() error { +func (he *HooksExecutor) OnStopReplication() error { return he.executeHooks(onStopReplication) } -func (he *HooksExecutor) onStartReplication() error { +func (he *HooksExecutor) OnStartReplication() error { return he.executeHooks(onStartReplication) } diff --git a/go/logic/hooks_test.go b/go/logic/hooks_test.go index 94ea62b07..0729df31f 100644 --- a/go/logic/hooks_test.go +++ b/go/logic/hooks_test.go @@ -8,6 +8,7 @@ package logic import ( "bufio" "bytes" + "errors" "fmt" "os" "path/filepath" @@ -21,6 +22,140 @@ import ( "github.com/github/gh-ost/go/base" ) +type recordingHooks struct { + name string + calls *[]string + errOn string + errVal error +} + +func (r *recordingHooks) record(method string) error { + *r.calls = append(*r.calls, r.name+":"+method) + if r.errOn == method { + return r.errVal + } + return nil +} + +func (r *recordingHooks) OnStartup() 
error { return r.record("OnStartup") } +func (r *recordingHooks) OnValidated() error { return r.record("OnValidated") } +func (r *recordingHooks) OnRowCountComplete() error { return r.record("OnRowCountComplete") } +func (r *recordingHooks) OnBeforeRowCopy() error { return r.record("OnBeforeRowCopy") } +func (r *recordingHooks) OnRowCopyComplete() error { return r.record("OnRowCopyComplete") } +func (r *recordingHooks) OnBeginPostponed() error { return r.record("OnBeginPostponed") } +func (r *recordingHooks) OnBeforeCutOver() error { return r.record("OnBeforeCutOver") } +func (r *recordingHooks) OnInteractiveCommand(string) error { + return r.record("OnInteractiveCommand") +} +func (r *recordingHooks) OnSuccess(bool) error { return r.record("OnSuccess") } +func (r *recordingHooks) OnFailure() error { return r.record("OnFailure") } +func (r *recordingHooks) OnBatchCopyRetry(string) error { return r.record("OnBatchCopyRetry") } +func (r *recordingHooks) OnStatus(string) error { return r.record("OnStatus") } +func (r *recordingHooks) OnStopReplication() error { return r.record("OnStopReplication") } +func (r *recordingHooks) OnStartReplication() error { return r.record("OnStartReplication") } + +func TestCompositeHooks_FanOut(t *testing.T) { + var calls []string + composite := CompositeHooks{ + &recordingHooks{name: "a", calls: &calls}, + &recordingHooks{name: "b", calls: &calls}, + &recordingHooks{name: "c", calls: &calls}, + } + + require.NoError(t, composite.OnStartup()) + require.NoError(t, composite.OnBeforeCutOver()) + require.NoError(t, composite.OnSuccess(true)) + + require.Equal(t, []string{ + "a:OnStartup", "b:OnStartup", "c:OnStartup", + "a:OnBeforeCutOver", "b:OnBeforeCutOver", "c:OnBeforeCutOver", + "a:OnSuccess", "b:OnSuccess", "c:OnSuccess", + }, calls) +} + +func TestCompositeHooks_SkipsNil(t *testing.T) { + var calls []string + composite := CompositeHooks{ + nil, + &recordingHooks{name: "a", calls: &calls}, + nil, + &recordingHooks{name: "b", calls: 
&calls}, + } + + require.NoError(t, composite.OnStartup()) + require.NoError(t, composite.OnSuccess(false)) + require.Equal(t, []string{"a:OnStartup", "b:OnStartup", "a:OnSuccess", "b:OnSuccess"}, calls) +} + +func TestCompositeHooks_FirstErrorWins(t *testing.T) { + var calls []string + boom := errors.New("boom") + composite := CompositeHooks{ + &recordingHooks{name: "a", calls: &calls}, + &recordingHooks{name: "b", calls: &calls, errOn: "OnBeforeCutOver", errVal: boom}, + &recordingHooks{name: "c", calls: &calls}, // must not be called + } + + err := composite.OnBeforeCutOver() + require.ErrorIs(t, err, boom) + require.Equal(t, []string{"a:OnBeforeCutOver", "b:OnBeforeCutOver"}, calls) +} + +var ( + _ base.Hooks = (*HooksExecutor)(nil) + _ base.Hooks = (CompositeHooks)(nil) +) + +func TestCompositeHooks_WithShellExecutor(t *testing.T) { + ctx := base.NewMigrationContext() + ctx.DatabaseName = "test" + ctx.OriginalTableName = "tablename" + + hooksDir, err := os.MkdirTemp("", "TestCompositeHooks_WithShellExecutor") + require.NoError(t, err) + defer os.RemoveAll(hooksDir) + ctx.HooksPath = hooksDir + + sideEffect := filepath.Join(hooksDir, "ran.marker") + script := fmt.Sprintf("#!/bin/sh\ntouch %q\n", sideEffect) + require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "gh-ost-on-startup"), []byte(script), 0o777)) + + shellExec := NewHooksExecutor(ctx) + shellExec.writer = new(bytes.Buffer) // suppress stderr noise + + var calls []string + goFake := &recordingHooks{name: "go", calls: &calls} + + composite := CompositeHooks{shellExec, goFake} + require.NoError(t, composite.OnStartup()) + + _, err = os.Stat(sideEffect) + require.NoError(t, err, "shell hook should have created marker file") + require.Equal(t, []string{"go:OnStartup"}, calls, "Go fake should have been invoked once") +} + +func TestNewMigrator_HooksFromContext(t *testing.T) { + t.Run("default-is-script-executor", func(t *testing.T) { + ctx := base.NewMigrationContext() + m := NewMigrator(ctx, "test") + 
_, ok := m.hooksExecutor.(*HooksExecutor) + require.True(t, ok) + }) + + t.Run("context-hooks-take-precedence", func(t *testing.T) { + var calls []string + fake := &recordingHooks{name: "fake", calls: &calls} + + ctx := base.NewMigrationContext() + ctx.Hooks = fake + m := NewMigrator(ctx, "test") + + require.Same(t, fake, m.hooksExecutor) + require.NoError(t, m.hooksExecutor.OnStartup()) + require.Equal(t, []string{"fake:OnStartup"}, calls) + }) +} + func TestHooksExecutorExecuteHooks(t *testing.T) { migrationContext := base.NewMigrationContext() migrationContext.AlterStatement = "ENGINE=InnoDB" diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 9dc6041f5..aaacbdd2e 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -84,7 +84,7 @@ type Migrator struct { eventsStreamer *EventsStreamer server *Server throttler *Throttler - hooksExecutor *HooksExecutor + hooksExecutor base.Hooks migrationContext *base.MigrationContext firstThrottlingCollected chan bool @@ -103,9 +103,13 @@ type Migrator struct { } func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { + hooks := context.Hooks + if hooks == nil { + hooks = NewHooksExecutor(context) + } migrator := &Migrator{ appVersion: appVersion, - hooksExecutor: NewHooksExecutor(context), + hooksExecutor: hooks, migrationContext: context, parser: sql.NewAlterTableParser(), ghostTableMigrated: make(chan bool), @@ -145,7 +149,7 @@ func (mgtr *Migrator) sleepWhileTrue(operation func() (bool, error)) error { func (mgtr *Migrator) retryBatchCopyWithHooks(operation func() error, notFatalHint ...bool) (err error) { wrappedOperation := func() error { if err := operation(); err != nil { - mgtr.hooksExecutor.onBatchCopyRetry(err.Error()) + mgtr.hooksExecutor.OnBatchCopyRetry(err.Error()) return err } return nil @@ -388,7 +392,7 @@ func (mgtr *Migrator) countTableRows() (err error) { if err := mgtr.inspector.CountTableRows(ctx); err != nil { return err } - if err := 
mgtr.hooksExecutor.onRowCountComplete(); err != nil { + if err := mgtr.hooksExecutor.OnRowCountComplete(); err != nil { return err } return nil @@ -456,7 +460,7 @@ func (mgtr *Migrator) Migrate() (err error) { go mgtr.listenOnPanicAbort() - if err := mgtr.hooksExecutor.onStartup(); err != nil { + if err := mgtr.hooksExecutor.OnStartup(); err != nil { return err } if err := mgtr.parser.ParseAlterStatement(mgtr.migrationContext.AlterStatement); err != nil { @@ -508,7 +512,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(true); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(true); err != nil { return err } mgtr.migrationContext.Log.Infof("Success! table %s.%s migrated instantly", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -565,7 +569,7 @@ func (mgtr *Migrator) Migrate() (err error) { } // Validation complete! We're good to execute this migration - if err := mgtr.hooksExecutor.onValidated(); err != nil { + if err := mgtr.hooksExecutor.OnValidated(); err != nil { return err } @@ -586,7 +590,7 @@ func (mgtr *Migrator) Migrate() (err error) { mgtr.initiateThrottler() - if err := mgtr.hooksExecutor.onBeforeRowCopy(); err != nil { + if err := mgtr.hooksExecutor.OnBeforeRowCopy(); err != nil { return err } go func() { @@ -609,7 +613,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.checkAbort(); err != nil { return err } - if err := mgtr.hooksExecutor.onRowCopyComplete(); err != nil { + if err := mgtr.hooksExecutor.OnRowCopyComplete(); err != nil { return err } mgtr.printStatus(ForcePrintStatusRule) @@ -618,7 +622,7 @@ func (mgtr *Migrator) Migrate() (err error) { mgtr.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over") mgtr.migrationContext.CancelTableRowsCount() } - if err := mgtr.hooksExecutor.onBeforeCutOver(); err != 
nil { + if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { return err } var retrier func(func() error, ...bool) error @@ -644,7 +648,7 @@ func (mgtr *Migrator) Migrate() (err error) { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(false); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { return err } mgtr.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -674,7 +678,7 @@ func (mgtr *Migrator) Revert() error { go mgtr.listenOnPanicAbort() - if err := mgtr.hooksExecutor.onStartup(); err != nil { + if err := mgtr.hooksExecutor.OnStartup(); err != nil { return err } if err := mgtr.validateAlterStatement(); err != nil { @@ -721,7 +725,7 @@ func (mgtr *Migrator) Revert() error { if err := mgtr.checkAbort(); err != nil { return err } - if err := mgtr.hooksExecutor.onValidated(); err != nil { + if err := mgtr.hooksExecutor.OnValidated(); err != nil { return err } if err := mgtr.initiateServer(); err != nil { @@ -748,7 +752,7 @@ func (mgtr *Migrator) Revert() error { } else { retrier = mgtr.retryOperation } - if err := mgtr.hooksExecutor.onBeforeCutOver(); err != nil { + if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { return err } if err := retrier(mgtr.cutOver); err != nil { @@ -758,7 +762,7 @@ func (mgtr *Migrator) Revert() error { if err := mgtr.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.onSuccess(false); err != nil { + if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { return err } mgtr.migrationContext.Log.Infof("Done reverting %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) @@ -768,7 +772,7 @@ func (mgtr *Migrator) Revert() error { // ExecOnFailureHook executes the onFailure hook, and this method is provided as the only external // hook access point func 
(mgtr *Migrator) ExecOnFailureHook() (err error) { - return mgtr.hooksExecutor.onFailure() + return mgtr.hooksExecutor.OnFailure() } func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { @@ -786,7 +790,7 @@ func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. - if err := mgtr.hooksExecutor.onStartReplication(); err != nil { + if err := mgtr.hooksExecutor.OnStartReplication(); err != nil { return mgtr.migrationContext.Log.Errore(err) } if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { @@ -834,7 +838,7 @@ func (mgtr *Migrator) cutOver() (err error) { if base.FileExists(mgtr.migrationContext.PostponeCutOverFlagFile) { // Postpone file defined and exists! if atomic.LoadInt64(&mgtr.migrationContext.IsPostponingCutOver) == 0 { - if err := mgtr.hooksExecutor.onBeginPostponed(); err != nil { + if err := mgtr.hooksExecutor.OnBeginPostponed(); err != nil { return true, err } } @@ -855,7 +859,7 @@ func (mgtr *Migrator) cutOver() (err error) { // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. 
- if err := mgtr.hooksExecutor.onStopReplication(); err != nil { + if err := mgtr.hooksExecutor.OnStopReplication(); err != nil { return err } if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { @@ -1410,7 +1414,7 @@ func (mgtr *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { hooksStatusIntervalSec := mgtr.migrationContext.HooksStatusIntervalSec if hooksStatusIntervalSec > 0 && elapsedSeconds%hooksStatusIntervalSec == 0 { - mgtr.hooksExecutor.onStatus(status) + mgtr.hooksExecutor.OnStatus(status) } } diff --git a/go/logic/server.go b/go/logic/server.go index 7fe171f6e..4705ba9b9 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -38,12 +38,12 @@ type Server struct { migrationContext *base.MigrationContext unixListener net.Listener tcpListener net.Listener - hooksExecutor *HooksExecutor + hooksExecutor base.Hooks printStatus printStatusFunc isCPUProfiling int64 } -func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor base.Hooks, printStatus printStatusFunc) *Server { return &Server{ migrationContext: migrationContext, hooksExecutor: hooksExecutor, @@ -206,7 +206,7 @@ func (srv *Server) applyServerCommand(command string, writer *bufio.Writer) (pri argIsQuestion := (arg == "?") throttleHint := "# Note: you may only throttle for as long as your binary logs are not purged" - if err := srv.hooksExecutor.onInteractiveCommand(command); err != nil { + if err := srv.hooksExecutor.OnInteractiveCommand(command); err != nil { return NoPrintStatusRule, err }