diff --git a/go.mod b/go.mod index 2eb240ea39..caa69ba2f6 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,9 @@ require ( github.com/containerd/errdefs v1.0.0 github.com/containerd/errdefs/pkg v0.3.0 github.com/containerd/go-runc v1.1.0 + github.com/containerd/log v0.1.0 github.com/containerd/platforms v1.0.0-rc.1 + github.com/containerd/plugin v1.0.0 github.com/containerd/ttrpc v1.2.7 github.com/containerd/typeurl/v2 v2.2.3 github.com/google/go-cmp v0.7.0 @@ -76,8 +78,6 @@ require ( github.com/checkpoint-restore/go-criu/v6 v6.3.0 // indirect github.com/containerd/continuity v0.4.5 // indirect github.com/containerd/fifo v1.1.0 // indirect - github.com/containerd/log v0.1.0 // indirect - github.com/containerd/plugin v1.0.0 // indirect github.com/containerd/protobuild v0.3.0 // indirect github.com/containerd/stargz-snapshotter/estargz v0.15.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect diff --git a/pkg/shim/publisher.go b/pkg/shim/publisher.go new file mode 100644 index 0000000000..5318c6cfe0 --- /dev/null +++ b/pkg/shim/publisher.go @@ -0,0 +1,186 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "errors" + "sync" + "time" + + v1 "github.com/containerd/containerd/api/services/ttrpc/events/v1" + "github.com/containerd/containerd/api/types" + "github.com/containerd/containerd/v2/core/events" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/containerd/v2/pkg/protobuf" + "github.com/containerd/containerd/v2/pkg/ttrpcutil" + "github.com/containerd/log" + "github.com/containerd/ttrpc" + "github.com/containerd/typeurl/v2" +) + +const ( + queueSize = 2048 + maxRequeue = 5 +) + +type item struct { + ev *types.Envelope + ctx context.Context + count int +} + +type publisherConfig struct { + ttrpcOpts []ttrpc.ClientOpts +} + +type PublisherOpts func(*publisherConfig) + +func WithPublishTTRPCOpts(opts ...ttrpc.ClientOpts) PublisherOpts { + return func(cfg *publisherConfig) { + cfg.ttrpcOpts = append(cfg.ttrpcOpts, opts...) + } +} + +// NewPublisher creates a new remote events publisher +func NewPublisher(address string, opts ...PublisherOpts) (*RemoteEventsPublisher, error) { + client, err := ttrpcutil.NewClient(address) + if err != nil { + return nil, err + } + + l := &RemoteEventsPublisher{ + client: client, + closed: make(chan struct{}), + requeue: make(chan *item, queueSize), + } + + go l.processQueue() + return l, nil +} + +// RemoteEventsPublisher forwards events to a ttrpc server +type RemoteEventsPublisher struct { + client *ttrpcutil.Client + closed chan struct{} + closer sync.Once + requeue chan *item +} + +// Done returns a channel which closes when done +func (l *RemoteEventsPublisher) Done() <-chan struct{} { + return l.closed +} + +// Close closes the remote connection and closes the done channel +func (l *RemoteEventsPublisher) Close() (err error) { + err = l.client.Close() + l.closer.Do(func() { + close(l.closed) + }) + return err +} + +func (l *RemoteEventsPublisher) processQueue() { + for i := range l.requeue { + if i.count > maxRequeue { + log.L.Errorf("evicting %s from queue because of retry count", i.ev.Topic) + // drop the event + continue + } + + if err := l.forwardRequest(i.ctx, &v1.ForwardRequest{Envelope: i.ev}); err != nil { + log.L.WithError(err).Error("forward event") + l.queue(i) + } + } +} + +func (l *RemoteEventsPublisher) queue(i *item) { + go func() { + i.count++ + // re-queue after a short delay + time.Sleep(time.Duration(1*i.count) * time.Second) + l.requeue <- i + }() +} + +// Publish publishes the event by forwarding it to the configured ttrpc server +func (l *RemoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return err + } + evt, err := typeurl.MarshalAnyToProto(event) + if err != nil { + return err + } + i := &item{ + ev: &types.Envelope{ + Timestamp: protobuf.ToTimestamp(time.Now()), + Namespace: ns, + Topic: topic, + Event: evt, + }, + ctx: ctx, + } + + if err := l.forwardRequest(i.ctx, &v1.ForwardRequest{Envelope: i.ev}); err != nil { + l.queue(i) + return err + } + + return nil +} + +func (l *RemoteEventsPublisher) forwardRequest(ctx context.Context, req *v1.ForwardRequest) error { + service, err := l.client.EventsService() + if err == nil { + fCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + _, err = service.Forward(fCtx, req) + cancel() + if err == nil { + return nil + } + } + + if !errors.Is(err, ttrpc.ErrClosed) { + return err + } + + // Reconnect and retry request + if err = l.client.Reconnect(); err != nil { + return err + } + + service, err = l.client.EventsService() + if err != nil { + return err + } + + // try again with a fresh context, otherwise we may get a context timeout unexpectedly. + fCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + _, err = service.Forward(fCtx, req) + cancel() + if err != nil { + return err + } + + return nil +} diff --git a/pkg/shim/shim.go b/pkg/shim/shim.go new file mode 100644 index 0000000000..668572a5c9 --- /dev/null +++ b/pkg/shim/shim.go @@ -0,0 +1,532 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "runtime/debug" + "time" + + shimapi "github.com/containerd/containerd/api/runtime/task/v3" + "github.com/containerd/containerd/api/types" + "github.com/containerd/containerd/v2/core/events" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/containerd/v2/pkg/protobuf" + "github.com/containerd/containerd/v2/pkg/protobuf/proto" + "github.com/containerd/containerd/v2/pkg/shutdown" + "github.com/containerd/containerd/v2/plugins" + "github.com/containerd/containerd/v2/version" + "github.com/containerd/log" + "github.com/containerd/plugin" + "github.com/containerd/plugin/registry" + "github.com/containerd/ttrpc" +) + +// Publisher for events +type Publisher interface { + events.Publisher + io.Closer +} + +// StartOpts describes shim start configuration received from containerd +type StartOpts struct { + Address string + TTRPCAddress string + Debug bool +} + +// BootstrapParams is a JSON payload returned in stdout from shim.Start call. +type BootstrapParams struct { + // Version is the version of shim parameters (expected 2 for shim v2) + Version int `json:"version"` + // Address is a address containerd should use to connect to shim. + Address string `json:"address"` + // Protocol is either TTRPC or GRPC. + Protocol string `json:"protocol"` +} + +type StopStatus struct { + Pid int + ExitStatus int + ExitedAt time.Time +} + +// Manager is the interface which manages the shim process +type Manager interface { + Name() string + Start(ctx context.Context, id string, opts StartOpts) (BootstrapParams, error) + Stop(ctx context.Context, id string) (StopStatus, error) + Info(ctx context.Context, optionsR io.Reader) (*types.RuntimeInfo, error) +} + +// OptsKey is the context key for the Opts value. +type OptsKey struct{} + +// Opts are context options associated with the shim invocation. +type Opts struct { + BundlePath string + Debug bool +} + +// BinaryOpts allows the configuration of a shims binary setup +type BinaryOpts func(*Config) + +// Config of shim binary options provided by shim implementations +type Config struct { + // NoSubreaper disables setting the shim as a child subreaper + NoSubreaper bool + // NoReaper disables the shim binary from reaping any child process implicitly + NoReaper bool + // NoSetupLogger disables automatic configuration of logrus to use the shim FIFO + NoSetupLogger bool +} + +type TTRPCService interface { + RegisterTTRPC(*ttrpc.Server) error +} + +type TTRPCServerUnaryOptioner interface { + UnaryServerInterceptor() ttrpc.UnaryServerInterceptor +} + +type TTRPCClientUnaryOptioner interface { + UnaryClientInterceptor() ttrpc.UnaryClientInterceptor +} + +var ( + debugFlag bool + versionFlag bool + infoFlag bool + id string + namespaceFlag string + socketFlag string + debugSocketFlag string + bundlePath string + addressFlag string + containerdBinaryFlag string + action string +) + +const ( + ttrpcAddressEnv = "TTRPC_ADDRESS" + grpcAddressEnv = "GRPC_ADDRESS" + namespaceEnv = "NAMESPACE" + maxVersionEnv = "MAX_SHIM_VERSION" +) + +func parseFlags() { + flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") + flag.BoolVar(&versionFlag, "v", false, "show the shim version and exit") + // "info" is not a subcommand, because old shims produce very confusing errors for unknown subcommands + // https://github.com/containerd/containerd/pull/8509#discussion_r1210021403 + flag.BoolVar(&infoFlag, "info", false, "get the option protobuf from stdin, print the shim info protobuf to stdout, and exit") + flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") + flag.StringVar(&id, "id", "", "id of the task") + flag.StringVar(&socketFlag, "socket", "", "socket path to serve") + flag.StringVar(&debugSocketFlag, "debug-socket", "", "debug socket path to serve") + flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir") + + flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") + flag.StringVar(&containerdBinaryFlag, "publish-binary", "", + fmt.Sprintf("path to publish binary (used for publishing events), but %s will ignore this flag, please use the %s env", os.Args[0], ttrpcAddressEnv), + ) + + flag.Parse() + action = flag.Arg(0) +} + +func setRuntime() { + debug.SetGCPercent(40) + go func() { + for range time.Tick(30 * time.Second) { + debug.FreeOSMemory() + } + }() + if os.Getenv("GOMAXPROCS") == "" { + // If GOMAXPROCS hasn't been set, we default to a value of 2 to reduce + // the number of Go stacks present in the shim. + runtime.GOMAXPROCS(2) + } +} + +func setLogger(ctx context.Context, id string) (context.Context, error) { + l := log.G(ctx) + _ = log.SetFormat(log.TextFormat) + if debugFlag { + _ = log.SetLevel("debug") + } + f, err := openLog(ctx, id) + if err != nil { + return ctx, err + } + l.Logger.SetOutput(f) + return log.WithLogger(ctx, l), nil +} + +// Run initializes and runs a shim server. +func Run(ctx context.Context, manager Manager, opts ...BinaryOpts) { + var config Config + for _, o := range opts { + o(&config) + } + + ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", manager.Name())) + + if err := run(ctx, manager, config); err != nil { + fmt.Fprintf(os.Stderr, "%s: %s", manager.Name(), err) + os.Exit(1) + } +} + +func runInfo(ctx context.Context, manager Manager) error { + info, err := manager.Info(ctx, os.Stdin) + if err != nil { + return err + } + infoB, err := proto.Marshal(info) + if err != nil { + return err + } + _, err = os.Stdout.Write(infoB) + return err +} + +func run(ctx context.Context, manager Manager, config Config) error { + parseFlags() + if versionFlag { + fmt.Printf("%s:\n", filepath.Base(os.Args[0])) + fmt.Println(" Version: ", version.Version) + fmt.Println(" Revision:", version.Revision) + fmt.Println(" Go version:", version.GoVersion) + fmt.Println("") + return nil + } + + if infoFlag { + return runInfo(ctx, manager) + } + + if namespaceFlag == "" { + return fmt.Errorf("shim namespace cannot be empty") + } + + setRuntime() + + signals, err := setupSignals(config) + if err != nil { + return err + } + + if !config.NoSubreaper { + if err := subreaper(); err != nil { + return err + } + } + + ttrpcAddress := os.Getenv(ttrpcAddressEnv) + + ctx = namespaces.WithNamespace(ctx, namespaceFlag) + ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag}) + ctx, sd := shutdown.WithShutdown(ctx) + defer sd.Shutdown() + + // Handle explicit actions + switch action { + case "delete": + logger := log.G(ctx).WithFields(log.Fields{ + "pid": os.Getpid(), + "namespace": namespaceFlag, + }) + if debugFlag { + logger.Logger.SetLevel(log.DebugLevel) + } + go func() { + _ = reap(ctx, logger, signals) + }() + ss, err := manager.Stop(ctx, id) + if err != nil { + return err + } + data, err := proto.Marshal(&shimapi.DeleteResponse{ + Pid: uint32(ss.Pid), + ExitStatus: uint32(ss.ExitStatus), + ExitedAt: protobuf.ToTimestamp(ss.ExitedAt), + }) + if err != nil { + return err + } + if _, err := os.Stdout.Write(data); err != nil { + return err + } + return nil + case "start": + opts := StartOpts{ + Address: addressFlag, + TTRPCAddress: ttrpcAddress, + Debug: debugFlag, + } + + params, err := manager.Start(ctx, id, opts) + if err != nil { + return err + } + + data, err := json.Marshal(¶ms) + if err != nil { + return fmt.Errorf("failed to marshal bootstrap params to json: %w", err) + } + + if _, err := os.Stdout.Write(data); err != nil { + return err + } + + return nil + } + + if !config.NoSetupLogger { + ctx, err = setLogger(ctx, id) + if err != nil { + return err + } + } + + registry.Register(&plugin.Registration{ + Type: plugins.InternalPlugin, + ID: "shutdown", + InitFn: func(ic *plugin.InitContext) (interface{}, error) { + return sd, nil + }, + }) + + // Register event plugin + registry.Register(&plugin.Registration{ + Type: plugins.EventPlugin, + ID: "publisher", + InitFn: func(ic *plugin.InitContext) (interface{}, error) { + return NewPublisher(ttrpcAddress, func(cfg *publisherConfig) { + p, _ := ic.GetByID(plugins.TTRPCPlugin, "otelttrpc") + if p == nil { + return + } + + opts := ttrpc.WithUnaryClientInterceptor(p.(TTRPCClientUnaryOptioner).UnaryClientInterceptor()) + WithPublishTTRPCOpts(opts)(cfg) + }) + }, + }) + + var ( + initialized = plugin.NewPluginSet() + ttrpcServices = []TTRPCService{} + + ttrpcUnaryInterceptors = []ttrpc.UnaryServerInterceptor{} + + pprofHandler server + ) + + for _, p := range registry.Graph(func(*plugin.Registration) bool { return false }) { + pID := p.URI() + log.G(ctx).WithFields(log.Fields{"id": pID, "type": p.Type}).Debug("loading plugin") + + initContext := plugin.NewContext( + ctx, + initialized, + map[string]string{ + // NOTE: Root is empty since the shim does not support persistent storage, + // shim plugins should make use state directory for writing files to disk. + // The state directory will be destroyed when the shim if cleaned up or + // on reboot + plugins.PropertyStateDir: filepath.Join(bundlePath, p.URI()), + plugins.PropertyGRPCAddress: addressFlag, + plugins.PropertyTTRPCAddress: ttrpcAddress, + }, + ) + + // load the plugin specific configuration if it is provided + // TODO: Read configuration passed into shim, or from state directory? + // if p.Config != nil { + // pc, err := config.Decode(p) + // if err != nil { + // return nil, err + // } + // initContext.Config = pc + // } + + result := p.Init(initContext) + if err := initialized.Add(result); err != nil { + return fmt.Errorf("could not add plugin result to plugin set: %w", err) + } + + instance, err := result.Instance() + if err != nil { + if plugin.IsSkipPlugin(err) { + log.G(ctx).WithFields(log.Fields{"id": pID, "type": p.Type, "error": err}).Info("skip loading plugin") + continue + } + return fmt.Errorf("failed to load plugin %s: %w", pID, err) + } + + if src, ok := instance.(TTRPCService); ok { + log.G(ctx).WithField("id", pID).Debug("registering ttrpc service") + ttrpcServices = append(ttrpcServices, src) + } + + if src, ok := instance.(TTRPCServerUnaryOptioner); ok { + ttrpcUnaryInterceptors = append(ttrpcUnaryInterceptors, src.UnaryServerInterceptor()) + } + + if result.Registration.ID == "pprof" { + if src, ok := instance.(server); ok { + pprofHandler = src + } + } + } + + if len(ttrpcServices) == 0 { + return fmt.Errorf("required that ttrpc service") + } + + unaryInterceptor := chainUnaryServerInterceptors(ttrpcUnaryInterceptors...) + server, err := newServer(ttrpc.WithUnaryServerInterceptor(unaryInterceptor)) + if err != nil { + return fmt.Errorf("failed creating server: %w", err) + } + + for _, srv := range ttrpcServices { + if err := srv.RegisterTTRPC(server); err != nil { + return fmt.Errorf("failed to register service: %w", err) + } + } + + if err := serve(ctx, server, signals, sd.Shutdown, pprofHandler); err != nil { + if !errors.Is(err, shutdown.ErrShutdown) { + cleanupSockets(ctx) + return err + } + } + + // NOTE: If the shim server is down(like oom killer), the address + // socket might be leaking. + cleanupSockets(ctx) + + select { + case <-sd.Done(): + return nil + case <-time.After(5 * time.Second): + return errors.New("shim shutdown timeout") + } +} + +// serve serves the ttrpc API over a unix socket in the current working directory +// and blocks until the context is canceled +func serve(ctx context.Context, server *ttrpc.Server, signals chan os.Signal, shutdown func(), pprof server) error { + dump := make(chan os.Signal, 32) + setupDumpStacks(dump) + + path, err := os.Getwd() + if err != nil { + return err + } + + l, err := serveListener(socketFlag, 3) + if err != nil { + return err + } + + serrs := make(chan error, 1) + defer close(serrs) + go func() { + defer l.Close() + if err := server.Serve(ctx, l); err != nil && !errors.Is(err, net.ErrClosed) { + log.G(ctx).WithError(err).Fatal("containerd-shim: ttrpc server failure") + serrs <- err + return + } + serrs <- nil + }() + + // Notify the parent process that the shim is ready. + // On Windows this signals a named event; on Unix this is a no-op. + if err = notifyReady(ctx, serrs); err != nil { + return err + } + + if debugFlag && pprof != nil { + if err := setupPprof(ctx, pprof); err != nil { + log.G(ctx).WithError(err).Warn("Could not setup pprof") + } + } + + logger := log.G(ctx).WithFields(log.Fields{ + "pid": os.Getpid(), + "path": path, + "namespace": namespaceFlag, + }) + go func() { + for range dump { + dumpStacks(logger) + } + }() + + go handleExitSignals(ctx, logger, shutdown) + return reap(ctx, logger, signals) +} + +func dumpStacks(logger *log.Entry) { + var ( + buf []byte + stackSize int + ) + bufferLen := 16384 + for stackSize == len(buf) { + buf = make([]byte, bufferLen) + stackSize = runtime.Stack(buf, true) + bufferLen *= 2 + } + buf = buf[:stackSize] + logger.Infof("=== BEGIN goroutine stack dump ===\n%s\n=== END goroutine stack dump ===", buf) +} + +type server interface { + Serve(net.Listener) error +} + +func setupPprof(ctx context.Context, srv server) error { + l, err := serveListener(debugSocketFlag, 4) + if err != nil { + return fmt.Errorf("could not setup pprof listener: %w", err) + } + + go func() { + if err := srv.Serve(l); err != nil && !errors.Is(err, net.ErrClosed) { + log.G(ctx).WithError(err).Fatal("containerd-shim: pprof endpoint failure") + } + }() + + return nil +} diff --git a/pkg/shim/shim_test.go b/pkg/shim/shim_test.go new file mode 100644 index 0000000000..46bc68405a --- /dev/null +++ b/pkg/shim/shim_test.go @@ -0,0 +1,64 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "runtime" + "testing" +) + +func TestRuntimeWithEmptyMaxEnvProcs(t *testing.T) { + var oldGoMaxProcs = runtime.GOMAXPROCS(0) + defer runtime.GOMAXPROCS(oldGoMaxProcs) + + t.Setenv("GOMAXPROCS", "") + setRuntime() + + var currentGoMaxProcs = runtime.GOMAXPROCS(0) + if currentGoMaxProcs != 2 { + t.Fatal("the max number of procs should be 2") + } +} + +func TestRuntimeWithNonEmptyMaxEnvProcs(t *testing.T) { + t.Setenv("GOMAXPROCS", "not_empty") + setRuntime() + var oldGoMaxProcs2 = runtime.GOMAXPROCS(0) + if oldGoMaxProcs2 != runtime.NumCPU() { + t.Fatal("the max number CPU should be equal to available CPUs") + } +} + +func TestShimOptWithValue(t *testing.T) { + ctx := context.TODO() + ctx = context.WithValue(ctx, OptsKey{}, Opts{Debug: true}) + + o := ctx.Value(OptsKey{}) + if o == nil { + t.Fatal("opts nil") + } + op, ok := o.(Opts) + if !ok { + t.Fatal("opts not of type Opts") + } + if !op.Debug { + t.Fatal("opts.Debug should be true") + } +} diff --git a/pkg/shim/shim_windows.go b/pkg/shim/shim_windows.go new file mode 100644 index 0000000000..4b832b984b --- /dev/null +++ b/pkg/shim/shim_windows.go @@ -0,0 +1,253 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "fmt" + "io" + "net" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + winio "github.com/Microsoft/go-winio" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/log" + "github.com/containerd/ttrpc" + "golang.org/x/sys/windows" +) + +// notifyReady signals the parent process that the shim server is ready. +// On Windows, this closes stdout and sets a named event that the parent +// process is waiting on to know the shim has started successfully. +func notifyReady(_ context.Context, serrs chan error) error { + select { + case err := <-serrs: + return err + case <-time.After(2 * time.Millisecond): + // This is our best indication that we have not errored on creation + // and are successfully serving the API. + os.Stdout.Close() + eventName, _ := windows.UTF16PtrFromString(fmt.Sprintf("%s-%s", namespaceFlag, id)) + // Open the existing event and set it to wake up the parent process which is waiting for the shim to be ready. + handle, err := windows.OpenEvent(windows.EVENT_MODIFY_STATE, false, eventName) + if err == nil { + _ = windows.SetEvent(handle) // Wake up the parent + _ = windows.CloseHandle(handle) // Clean up + } + } + return nil +} + +// setupSignals creates a signal channel for Windows. +// On Windows, we don't register any signals here because: +// 1. Child process reaping (SIGCHLD) is not needed - the OS handles it. +// 2. Exit signals (SIGINT/SIGTERM) are handled by handleExitSignals separately. +// We return an empty channel that reap() can use, but it won't receive signals. +func setupSignals(_ Config) (chan os.Signal, error) { + signals := make(chan os.Signal, 32) + return signals, nil +} + +// newServer creates a new ttrpc server for Windows. +// Unlike Unix, Windows doesn't have user-based socket authentication, +// so we create a basic ttrpc server without the handshaker. +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + return ttrpc.NewServer(opts...) +} + +// subreaper is not applicable on Windows as the OS automatically +// handles orphaned processes differently than Unix systems. +func subreaper() error { + // This is a no-op on Windows - the OS handles orphaned processes + return nil +} + +// setupDumpStacks is currently not implemented for Windows. +// Windows doesn't have SIGUSR1, so stack dumping would need to use +// a different mechanism (e.g., a named event or debug console). +func setupDumpStacks(_ chan<- os.Signal) { + // No-op on Windows - SIGUSR1 doesn't exist + // Future: could implement using Windows events or console signals +} + +// serveListener creates a named pipe listener for Windows. +// If path is provided, it creates a new named pipe at that location. +// If path is empty and fd is provided, it attempts to inherit the listener (not commonly used on Windows). +func serveListener(path string, _ uintptr) (net.Listener, error) { + if path == "" { + // On Windows, inheriting file descriptors is more complex and rarely used + // with named pipes. We'll return an error if no path is provided. + return nil, fmt.Errorf("named pipe path is required on Windows") + } + + // Ensure the path is in the correct Windows named pipe format + // Expected format: \\.\pipe\ + if !strings.HasPrefix(path, `\\.\pipe`) { + return nil, fmt.Errorf("socket is required to be pipe address") + } + + l, err := winio.ListenPipe(path, nil) + if err != nil { + return nil, fmt.Errorf("failed to create named pipe listener at %s: %w", path, err) + } + + log.L.WithField("pipe", path).Debug("serving api on named pipe") + return l, nil +} + +// reap handles signals on Windows. Unlike Unix, Windows doesn't send SIGCHLD +// when child processes exit, so we only need to handle shutdown signals. +func reap(ctx context.Context, logger *log.Entry, signals chan os.Signal) error { + logger.Debug("starting signal loop") + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case s := <-signals: + logger.WithField("signal", s).Debug("received signal in reap loop") + // On Windows, we just log the signal + // Exit signals are handled in handleExitSignals + } + } +} + +// handleExitSignals listens for shutdown signals (SIGINT, SIGTERM) and +// triggers the provided cancel function for graceful shutdown. +func handleExitSignals(ctx context.Context, logger *log.Entry, cancel context.CancelFunc) { + ch := make(chan os.Signal, 32) + // On Windows, os.Kill cannot be caught. We handle os.Interrupt (Ctrl+C) and SIGTERM. + signal.Notify(ch, os.Interrupt, syscall.SIGTERM) + + for { + select { + case s := <-ch: + logger.WithField("signal", s).Debug("caught exit signal") + cancel() + return + case <-ctx.Done(): + return + } + } +} + +// openLog creates a named pipe for shim logging on Windows. +// The containerd daemon connects to this pipe as a client to read log output. +// The pipe format is: \\.\pipe\containerd-shim-{namespace}-{id}-log +func openLog(ctx context.Context, id string) (io.Writer, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + pipePath := fmt.Sprintf("\\\\.\\pipe\\containerd-shim-%s-%s-log", ns, id) + l, err := winio.ListenPipe(pipePath, nil) + if err != nil { + return nil, fmt.Errorf("failed to create shim log pipe: %w", err) + } + + rlw := &reconnectingLogWriter{ + l: l, + } + + // Accept connections from containerd in the background. + // Supports reconnection if containerd restarts. + go rlw.acceptConnections() + + return rlw, nil +} + +// reconnectingLogWriter is a writer that accepts log connections from containerd. +// It supports reconnection - if containerd restarts, a new connection is accepted +// and the old one is closed. Logs generated during reconnection may be lost. +type reconnectingLogWriter struct { + l net.Listener // The named pipe listener waiting for connections + mu sync.Mutex // Protects the current connection + conn net.Conn // The current active connection (may be nil) +} + +// acceptConnections listens for log connections in the background. +func (rlw *reconnectingLogWriter) acceptConnections() { + for { + newConn, err := rlw.l.Accept() + if err != nil { + // Listener was closed, stop accepting + return + } + + rlw.mu.Lock() + // Close the old connection if one exists + if rlw.conn != nil { + rlw.conn.Close() + } + rlw.conn = newConn + rlw.mu.Unlock() + } +} + +// Write implements io.Writer. It writes to the current connection if one exists. +// If no connection is established yet, writes are silently dropped to avoid +// blocking the shim. +func (rlw *reconnectingLogWriter) Write(p []byte) (n int, err error) { + rlw.mu.Lock() + conn := rlw.conn + rlw.mu.Unlock() + + if conn == nil { + // No connection yet, drop the log. + return len(p), nil + } + + n, err = conn.Write(p) + if err != nil { + // Connection may have been closed, clear it so next write + // doesn't try to use a broken connection + rlw.mu.Lock() + if rlw.conn == conn { + rlw.conn.Close() + rlw.conn = nil + } + rlw.mu.Unlock() + // Return success anyway to avoid log write errors propagating + return len(p), nil + } + return n, nil +} + +// Close implements io.Closer. It closes both the listener and any active connection. +func (rlw *reconnectingLogWriter) Close() error { + rlw.mu.Lock() + defer rlw.mu.Unlock() + + var err error + if rlw.l != nil { + err = rlw.l.Close() + } + if rlw.conn != nil { + if cerr := rlw.conn.Close(); cerr != nil && err == nil { + err = cerr + } + rlw.conn = nil + } + return err +} diff --git a/pkg/shim/shim_windows_test.go b/pkg/shim/shim_windows_test.go new file mode 100644 index 0000000000..7d6f62eff9 --- /dev/null +++ b/pkg/shim/shim_windows_test.go @@ -0,0 +1,476 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "fmt" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + winio "github.com/Microsoft/go-winio" + "github.com/containerd/containerd/v2/pkg/namespaces" +) + +const ( + // connectionWaitTime is the time to wait for pipe connections to be established + connectionWaitTime = 50 * time.Millisecond + // readTimeout is the timeout for reading from pipe connections + readTimeout = time.Second +) + +// testPipeCounter ensures unique pipe names across parallel tests +var testPipeCounter atomic.Uint64 + +// uniquePipePath generates a unique pipe path for testing +func uniquePipePath(prefix string) string { + return fmt.Sprintf(`\\.\pipe\%s-%d-%d`, prefix, os.Getpid(), testPipeCounter.Add(1)) +} + +// createTestPipe creates a named pipe listener for testing and returns cleanup function +func createTestPipe(t *testing.T, pipePath string) net.Listener { + t.Helper() + l, err := winio.ListenPipe(pipePath, nil) + if err != nil { + t.Fatalf("failed to create test pipe: %v", err) + } + return l +} + +// connectToPipe connects to a pipe and returns cleanup function +func connectToPipe(t *testing.T, pipePath string) net.Conn { + t.Helper() + conn, err := winio.DialPipe(pipePath, nil) + if err != nil { + t.Fatalf("failed to connect to pipe: %v", err) + } + return conn +} + +// readResult holds the result of an async read operation +type readResult struct { + buf []byte + err error +} + +// asyncRead reads data from connection with timeout in a goroutine. +// Returns a channel that will receive the result when the read completes. +func asyncRead(conn net.Conn, expectedLen int) <-chan readResult { + resultChan := make(chan readResult, 1) + go func() { + buf := make([]byte, expectedLen) + _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) + nRead, err := conn.Read(buf) + resultChan <- readResult{buf: buf[:nRead], err: err} + }() + return resultChan +} + +func TestSetupSignals(t *testing.T) { + tests := []struct { + name string + config Config + expectError bool + expectNilChan bool + expectedCapacity int + }{ + { + name: "default config creates signal channel with capacity 32", + config: Config{}, + expectError: false, + expectNilChan: false, + expectedCapacity: 32, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signals, err := setupSignals(tt.config) + if (err != nil) != tt.expectError { + t.Fatalf("setupSignals() error = %v, expectError %v", err, tt.expectError) + } + if (signals == nil) != tt.expectNilChan { + t.Fatal("setupSignals returned unexpected nil channel state") + } + if signals != nil && cap(signals) != tt.expectedCapacity { + t.Fatalf("expected signal channel capacity %d, got %d", tt.expectedCapacity, cap(signals)) + } + }) + } +} + +func TestServeListener(t *testing.T) { + tests := []struct { + name string + path string + expectError bool + shouldClose bool + }{ + { + name: "empty path should fail", + path: "", + expectError: true, + shouldClose: false, + }, + { + name: "non-pipe path should fail", + path: "/tmp/invalid/path", + expectError: true, + shouldClose: false, + }, + { + name: "valid pipe path should succeed", + path: fmt.Sprintf(`\\.\pipe\containerd-shim-test-%d`, os.Getpid()), + expectError: false, + shouldClose: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l, err := serveListener(tt.path, 0) + if (err != nil) != tt.expectError { + t.Fatalf("serveListener() error = %v, expectError %v", err, tt.expectError) + } + if tt.shouldClose && l != nil { + defer l.Close() + } + if !tt.expectError && l == nil { + t.Fatal("serveListener returned nil listener") + } + }) + } +} + +func TestReconnectingLogWriterDropsLogsBeforeConnection(t *testing.T) { + t.Parallel() + pipePath := uniquePipePath("shim-test-log") + l := createTestPipe(t, pipePath) + + rlw := &reconnectingLogWriter{l: l} + go rlw.acceptConnections() + defer rlw.Close() + + // Write before any client connects - should not block and return success + testData := []byte("test log message before connection") + n, err := rlw.Write(testData) + if err != nil { + t.Fatalf("Write should not return error before connection: %v", err) + } + if n != len(testData) { + t.Fatalf("Write should return len(data) even when dropping: got %d, want %d", n, len(testData)) + } +} + +func TestReconnectingLogWriterWritesAfterConnection(t *testing.T) { + t.Parallel() + pipePath := uniquePipePath("shim-test-log-write") + l := createTestPipe(t, pipePath) + + rlw := &reconnectingLogWriter{l: l} + go rlw.acceptConnections() + defer rlw.Close() + + // Connect a client + clientConn := connectToPipe(t, pipePath) + defer clientConn.Close() + + // Give time for connection to be accepted + time.Sleep(connectionWaitTime) + + // Write after client connects + testData := []byte("test log message after connection") + + // Start reading from client side before writing to prevent blocking + readChan := asyncRead(clientConn, len(testData)) + + n, err := rlw.Write(testData) + if err != nil { + t.Fatalf("Write failed after connection: %v", err) + } + if n != len(testData) { + t.Fatalf("Write returned wrong length: got %d, want %d", n, len(testData)) + } + + // Wait for read to complete and verify + result := <-readChan + if result.err != nil { + t.Fatalf("client failed to read: %v", result.err) + } + if string(result.buf) != string(testData) { + t.Fatalf("client read wrong data: got %q, want %q", string(result.buf), string(testData)) + } +} + +func TestReconnectingLogWriterSupportsReconnection(t *testing.T) { + t.Parallel() + pipePath := uniquePipePath("shim-test-log-reconnect") + l := createTestPipe(t, pipePath) + + rlw := &reconnectingLogWriter{l: l} + go rlw.acceptConnections() + defer rlw.Close() + + // First client connects + client1 := connectToPipe(t, pipePath) + + // Give time for connection to be accepted + time.Sleep(connectionWaitTime) + + // Write with first client + testData1 := []byte("message to first client") + + // Start reading from first client before writing + readChan1 := asyncRead(client1, len(testData1)) + + _, err := rlw.Write(testData1) + if err != nil { + t.Fatalf("Write to first client failed: %v", err) + } + + // Wait for read to complete and verify + result1 := <-readChan1 + if result1.err != nil { + t.Fatalf("first client failed to read: %v", result1.err) + } + if string(result1.buf) != string(testData1) { + t.Fatalf("first client read wrong data: got %q, want %q", string(result1.buf), string(testData1)) + } + + // Second client connects (simulating containerd restart) + client2 := connectToPipe(t, pipePath) + defer client2.Close() + + // Give time for new connection to be accepted and old one closed + time.Sleep(connectionWaitTime) + + // Close first client (it should already be closed by the writer) + client1.Close() + + // Write with second client connected + testData2 := []byte("message to second client") + + // Start reading from second client before writing + readChan2 := asyncRead(client2, len(testData2)) + + _, err = rlw.Write(testData2) + if err != nil { + t.Fatalf("Write to second client failed: %v", err) + } + + // Wait for read to complete and verify + result2 := <-readChan2 + if result2.err != nil { + t.Fatalf("second client failed to read: %v", result2.err) + } + if string(result2.buf) != string(testData2) { + t.Fatalf("second client read wrong data: got %q, want %q", string(result2.buf), string(testData2)) + } +} + +func TestReconnectingLogWriterClose(t *testing.T) { + t.Parallel() + pipePath := uniquePipePath("shim-test-log-close") + l := createTestPipe(t, pipePath) + + rlw := &reconnectingLogWriter{l: l} + go rlw.acceptConnections() + + // Connect a client + client := connectToPipe(t, pipePath) + defer client.Close() + + // Give time for connection to be accepted + time.Sleep(connectionWaitTime) + + // Close the writer + err := rlw.Close() + if err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Verify listener is closed by trying to connect again + _, err = winio.DialPipe(pipePath, nil) + if err == nil { + t.Fatal("should not be able to connect after Close") + } +} + +func TestReconnectingLogWriterConcurrentWrites(t *testing.T) { + t.Parallel() + pipePath := uniquePipePath("shim-test-log-concurrent") + l := createTestPipe(t, pipePath) + + rlw := &reconnectingLogWriter{l: l} + go rlw.acceptConnections() + defer rlw.Close() + + // Connect a client + client := connectToPipe(t, pipePath) + defer client.Close() + + // Give time for connection to be accepted + time.Sleep(connectionWaitTime) + + // Start reading in background + readDone := make(chan struct{}) + go func() { + defer close(readDone) + buf := make([]byte, 4096) + for { + _ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, err := client.Read(buf) + if err != nil { + return + } + } + }() + + // Concurrent writes - collect errors instead of using t.Errorf in goroutine + const numWriters = 10 + errChan := make(chan error, numWriters) + var wg sync.WaitGroup + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + msg := fmt.Sprintf("concurrent message %d\n", id) + _, err := rlw.Write([]byte(msg)) + if err != nil { + errChan <- fmt.Errorf("concurrent Write %d failed: %w", id, err) + } + }(i) + } + wg.Wait() + close(errChan) + + // Report any errors from concurrent writes + for err := range errChan { + t.Error(err) + } + + // Close and wait for reader to finish + rlw.Close() + <-readDone +} + +func TestOpenLog(t *testing.T) { + tests := []struct { + name string + setupCtx func() context.Context + containerID string + expectError bool + shouldConnect bool + pipePath string + }{ + { + name: "creates named pipe and accepts connections", + setupCtx: func() context.Context { + return namespaces.WithNamespace(context.Background(), "test-ns") + }, + containerID: "test-container-id", + expectError: false, + shouldConnect: true, + pipePath: `\\.\pipe\containerd-shim-test-ns-test-container-id-log`, + }, + { + name: "fails without namespace in context", + setupCtx: func() context.Context { + return context.Background() + }, + containerID: "test-container-id", + expectError: true, + shouldConnect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupCtx() + + writer, err := openLog(ctx, tt.containerID) + if (err != nil) != tt.expectError { + t.Fatalf("openLog() error = %v, expectError %v", err, tt.expectError) + } + + if tt.expectError { + return + } + + defer writer.(interface{ Close() error }).Close() + + if tt.shouldConnect { + // Verify we can connect to the created pipe + client, err := winio.DialPipe(tt.pipePath, nil) + if err != nil { + t.Fatalf("failed to connect to log pipe: %v", err) + } + defer client.Close() + + // Give time for connection to be accepted + time.Sleep(connectionWaitTime) + + // Write should succeed + testData := []byte("test log from openLog") + + // Read from client in goroutine to prevent blocking + readDone := make(chan struct{}) + var readBuf []byte + var readErr error + go func() { + defer close(readDone) + buf := make([]byte, len(testData)) + _ = client.SetReadDeadline(time.Now().Add(readTimeout)) + nRead, err := client.Read(buf) + readBuf = buf[:nRead] + readErr = err + }() + + n, err := writer.Write(testData) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if n != len(testData) { + t.Fatalf("Write returned wrong length: got %d, want %d", n, len(testData)) + } + + // Wait for read to complete and verify client receives the data + <-readDone + if readErr != nil { + t.Fatalf("client failed to read: %v", readErr) + } + if string(readBuf) != string(testData) { + t.Fatalf("client read wrong data: got %q, want %q", string(readBuf), string(testData)) + } + } + }) + } +} + +func TestSubreaper(t *testing.T) { + // On Windows, subreaper is a no-op + err := subreaper() + if err != nil { + t.Fatalf("subreaper should return nil on Windows: %v", err) + } +} diff --git a/pkg/shim/util.go b/pkg/shim/util.go new file mode 100644 index 0000000000..06e0135af3 --- /dev/null +++ b/pkg/shim/util.go @@ -0,0 +1,220 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/containerd/ttrpc" + "github.com/containerd/typeurl/v2" + + "github.com/containerd/containerd/v2/pkg/atomicfile" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/containerd/v2/pkg/protobuf/proto" + "github.com/containerd/containerd/v2/pkg/protobuf/types" + "github.com/containerd/errdefs" +) + +type CommandConfig struct { + Runtime string + Address string + TTRPCAddress string + Path string + Args []string + Opts *types.Any + Env []string +} + +// Command returns the shim command with the provided args and configuration +func Command(ctx context.Context, config *CommandConfig) (*exec.Cmd, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + self, err := os.Executable() + if err != nil { + return nil, err + } + args := []string{ + "-namespace", ns, + "-address", config.Address, + "-publish-binary", self, + } + args = append(args, config.Args...) + cmd := exec.CommandContext(ctx, config.Runtime, args...) + cmd.Dir = config.Path + cmd.Env = append( + os.Environ(), + "GOMAXPROCS=2", + fmt.Sprintf("%s=2", maxVersionEnv), + fmt.Sprintf("%s=%s", ttrpcAddressEnv, config.TTRPCAddress), + fmt.Sprintf("%s=%s", grpcAddressEnv, config.Address), + fmt.Sprintf("%s=%s", namespaceEnv, ns), + ) + if len(config.Env) > 0 { + cmd.Env = append(cmd.Env, config.Env...) + } + cmd.SysProcAttr = getSysProcAttr() + if config.Opts != nil { + d, err := proto.Marshal(config.Opts) + if err != nil { + return nil, err + } + cmd.Stdin = bytes.NewReader(d) + } + return cmd, nil +} + +// BinaryName returns the shim binary name from the runtime name, +// empty string returns means runtime name is invalid +func BinaryName(runtime string) string { + // runtime name should format like $prefix.name.version + parts := strings.Split(runtime, ".") + if len(parts) < 2 || parts[0] == "" { + return "" + } + + return fmt.Sprintf(shimBinaryFormat, parts[len(parts)-2], parts[len(parts)-1]) +} + +// BinaryPath returns the full path for the shim binary from the runtime name, +// empty string returns means runtime name is invalid +func BinaryPath(runtime string) string { + dir := filepath.Dir(runtime) + binary := BinaryName(runtime) + + path, err := filepath.Abs(filepath.Join(dir, binary)) + if err != nil { + return "" + } + + return path +} + +// Connect to the provided address +func Connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) { + return d(address, 100*time.Second) +} + +// WritePidFile writes a pid file atomically +func WritePidFile(path string, pid int) error { + path, err := filepath.Abs(path) + if err != nil { + return err + } + f, err := atomicfile.New(path, 0o644) + if err != nil { + return err + } + _, err = fmt.Fprintf(f, "%d", pid) + if err != nil { + _ = f.Cancel() + return err + } + return f.Close() +} + +// ErrNoAddress is returned when the address file has no content +var ErrNoAddress = errors.New("no shim address") + +// ReadAddress returns the shim's socket address from the path +func ReadAddress(path string) (string, error) { + path, err := filepath.Abs(path) + if err != nil { + return "", err + } + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + if len(data) == 0 { + return "", ErrNoAddress + } + return string(data), nil +} + +// ReadRuntimeOptions reads config bytes from io.Reader and unmarshals it into the provided type. +// The type must be registered with typeurl. +// +// The function will return ErrNotFound, if the config is not provided. +// And ErrInvalidArgument, if unable to cast the config to the provided type T. +func ReadRuntimeOptions[T any](reader io.Reader) (T, error) { + var config T + + data, err := io.ReadAll(reader) + if err != nil { + return config, fmt.Errorf("failed to read config bytes from stdin: %w", err) + } + + if len(data) == 0 { + return config, errdefs.ErrNotFound + } + + var any types.Any + if err := proto.Unmarshal(data, &any); err != nil { + return config, err + } + + v, err := typeurl.UnmarshalAny(&any) + if err != nil { + return config, err + } + + config, ok := v.(T) + if !ok { + return config, fmt.Errorf("invalid type %T: %w", v, errdefs.ErrInvalidArgument) + } + + return config, nil +} + +// chainUnaryServerInterceptors creates a single ttrpc server interceptor from +// a chain of many interceptors executed from first to last. +func chainUnaryServerInterceptors(interceptors ...ttrpc.UnaryServerInterceptor) ttrpc.UnaryServerInterceptor { + n := len(interceptors) + + // force to use default interceptor in ttrpc + if n == 0 { + return nil + } + + return func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + currentMethod := method + + for i := n - 1; i > 0; i-- { + interceptor := interceptors[i] + innerMethod := currentMethod + + currentMethod = func(currentCtx context.Context, currentUnmarshal func(interface{}) error) (interface{}, error) { + return interceptor(currentCtx, currentUnmarshal, info, innerMethod) + } + } + return interceptors[0](ctx, unmarshal, info, currentMethod) + } +} diff --git a/pkg/shim/util_test.go b/pkg/shim/util_test.go new file mode 100644 index 0000000000..d68676b311 --- /dev/null +++ b/pkg/shim/util_test.go @@ -0,0 +1,120 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "path/filepath" + "reflect" + "testing" + + "github.com/containerd/ttrpc" +) + +func TestChainUnaryServerInterceptors(t *testing.T) { + methodInfo := &ttrpc.UnaryServerInfo{ + FullMethod: filepath.Join("/", t.Name(), "foo"), + } + + type callKey struct{} + callValue := "init" + callCtx := context.WithValue(context.Background(), callKey{}, callValue) + + verifyCallCtxFn := func(ctx context.Context, key interface{}, expected interface{}) { + got := ctx.Value(key) + if !reflect.DeepEqual(expected, got) { + t.Fatalf("[context(key:%s) expected %v, but got %v", key, expected, got) + } + } + + verifyInfoFn := func(info *ttrpc.UnaryServerInfo) { + if !reflect.DeepEqual(methodInfo, info) { + t.Fatalf("[info] expected %+v, but got %+v", methodInfo, info) + } + } + + origUnmarshaler := func(obj interface{}) error { + v := obj.(*int64) + *v *= 2 + return nil + } + + type firstKey struct{} + firstValue := "from first" + var firstUnmarshaler ttrpc.Unmarshaler + first := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyInfoFn(info) + + ctx = context.WithValue(ctx, firstKey{}, firstValue) + + firstUnmarshaler = func(obj interface{}) error { + if err := unmarshal(obj); err != nil { + return err + } + + v := obj.(*int64) + *v *= 2 + return nil + } + + return method(ctx, firstUnmarshaler) + } + + type secondKey struct{} + secondValue := "from second" + second := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyInfoFn(info) + + v := int64(3) // should return 12 + if err := unmarshal(&v); err != nil { + t.Fatalf("unexpected error %v", err) + } + if expected := int64(12); v != expected { + t.Fatalf("expected int64(%v), but got %v", expected, v) + } + + ctx = context.WithValue(ctx, secondKey{}, secondValue) + return method(ctx, unmarshal) + } + + methodFn := func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyCallCtxFn(ctx, secondKey{}, secondValue) + + v := int64(2) + if err := unmarshal(&v); err != nil { + return nil, err + } + return v, nil + } + + interceptor := chainUnaryServerInterceptors(first, second) + v, err := interceptor(callCtx, origUnmarshaler, methodInfo, methodFn) + if err != nil { + t.Fatalf("expected nil, but got %v", err) + } + + if expected := int64(8); v != expected { + t.Fatalf("expected result is int64(%v), but got %v", expected, v) + } +} diff --git a/pkg/shim/util_windows.go b/pkg/shim/util_windows.go new file mode 100644 index 0000000000..268744a284 --- /dev/null +++ b/pkg/shim/util_windows.go @@ -0,0 +1,99 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "fmt" + "net" + "os" + "strings" + "syscall" + "time" + + winio "github.com/Microsoft/go-winio" + "github.com/pkg/errors" +) + +const shimBinaryFormat = "containerd-shim-%s-%s.exe" + +func getSysProcAttr() *syscall.SysProcAttr { + return &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + +// AnonReconnectDialer returns a dialer for an existing npipe on containerd reconnection +func AnonReconnectDialer(address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if !strings.HasPrefix(address, `\\.\pipe`) { + return nil, fmt.Errorf("invalid pipe address: %s", address) + } + + c, err := winio.DialPipeContext(ctx, address) + if os.IsNotExist(err) { + return nil, fmt.Errorf("npipe not found on reconnect: %w", os.ErrNotExist) + } else if errors.Is(err, context.DeadlineExceeded) { + return nil, fmt.Errorf("timed out waiting for npipe %s: %w", address, err) + } else if err != nil { + return nil, err + } + return c, nil +} + +// AnonDialer returns a dialer for a npipe +func AnonDialer(address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if !strings.HasPrefix(address, `\\.\pipe`) { + return nil, fmt.Errorf("invalid pipe address: %s", address) + } + + // If there is nobody serving the pipe we limit the timeout for this case to + // 5 seconds because any shim that would serve this endpoint should serve it + // within 5 seconds. + serveTimer := time.NewTimer(5 * time.Second) + defer serveTimer.Stop() + for { + c, err := winio.DialPipeContext(ctx, address) + if err != nil { + if os.IsNotExist(err) { + select { + case <-serveTimer.C: + return nil, fmt.Errorf("pipe not found before timeout: %w", os.ErrNotExist) + default: + // Wait 10ms for the shim to serve and try again. + time.Sleep(10 * time.Millisecond) + continue + } + } else if errors.Is(err, context.DeadlineExceeded) { + return nil, fmt.Errorf("timed out waiting for npipe %s: %w", address, err) + } + return nil, err + } + return c, nil + } +} + +func cleanupSockets(_ context.Context) { + // On Windows, named pipes are automatically cleaned up when closed. +}