From 0f4fc8fbab01f95503fda9d76197fc75357f41db Mon Sep 17 00:00:00 2001 From: "p.zindyaev" Date: Thu, 6 Nov 2025 15:40:15 +0300 Subject: [PATCH] feature(systemd): added the ability to start a service with the predefined config --- .gitignore | 1 + README.MD | 13 +- cmd/ad-runtime-utils/app.go | 63 ++++++++++ configs/config.yaml | 2 +- examples/config/kafka.yaml | 17 +++ examples/systemd/kafka.service | 10 ++ go.mod | 5 +- go.sum | 2 + internal/config/config.go | 120 ++++++++++++++---- internal/exec/check.go | 115 +++++++++++++++++ internal/exec/exec.go | 32 +++++ internal/exec/socket.go | 219 +++++++++++++++++++++++++++++++++ internal/exec/socket_test.go | 43 +++++++ 13 files changed, 613 insertions(+), 29 deletions(-) create mode 100644 examples/config/kafka.yaml create mode 100644 examples/systemd/kafka.service create mode 100644 internal/exec/check.go create mode 100644 internal/exec/exec.go create mode 100644 internal/exec/socket.go create mode 100644 internal/exec/socket_test.go diff --git a/.gitignore b/.gitignore index 69dcb52..f6ccf46 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # Added by goreleaser init: dist/ +cmd/ad-runtime-utils/ad-runtime-utils diff --git a/README.MD b/README.MD index aa82650..2bd1629 100644 --- a/README.MD +++ b/README.MD @@ -91,4 +91,15 @@ Default runtimes: Service : : "> … -``` \ No newline at end of file +``` + +### 4. Starting a Service + +When the `--start` and/or `--supervise` flags are provided, `--service` is a required argument. + +- It starts the executable `services.<service>.executable` with the specified arguments from `services.<service>.executable_args`. + +- If the `--supervise` flag is provided, it will run the health checks defined in `services.<service>.health_checks` on service start to make sure the service is operational. If any of the checks fails, the service will be stopped. `Type=notify` should be used in the systemd unit (see the example in the `examples/systemd` directory). 
+ +- While using `--supervise` and health checks, make sure that the systemd service has a large enough `TimeoutStartSec`. Ideally, it should be at least the combined timeout of all health checks. + diff --git a/cmd/ad-runtime-utils/app.go b/cmd/ad-runtime-utils/app.go index d6ba747..a512ef3 100644 --- a/cmd/ad-runtime-utils/app.go +++ b/cmd/ad-runtime-utils/app.go @@ -4,10 +4,13 @@ import ( "flag" "fmt" "io" + "os" "strings" "github.com/arenadata/ad-runtime-utils/internal/config" "github.com/arenadata/ad-runtime-utils/internal/detect" + "github.com/arenadata/ad-runtime-utils/internal/exec" + "github.com/coreos/go-systemd/v22/daemon" ) // exit codes. @@ -27,6 +30,8 @@ func Run(args []string, stdout, stderr io.Writer) int { listAll := fs.Bool("list", false, "List all detected runtimes (default + services)") fs.BoolVar(listAll, "l", false, "shorthand for --list") printCACerts := fs.Bool("print-cacerts", false, "When used with --runtime=java, prints the cacerts path and exits") + start := fs.Bool("start", false, "Start the service. Use with simple/exec services") + supervise := fs.Bool("supervise", false, "Supervise the service. Use with notify systemd services") if err := fs.Parse(args); err != nil { return exitParseError @@ -76,6 +81,15 @@ func Run(args []string, stdout, stderr io.Writer) int { } envName := detectEnvName(cfg, *service, *runtime) + + if *start { + if err = startService(*service, envName, path, *cfg, *supervise); err != nil { + fmt.Fprintf(stderr, "start service failed: %v\n", err) + return exitUserError + } + return exitOK + } + fmt.Fprintf(stdout, "export %s=%s\n", envName, path) return exitOK } @@ -124,3 +138,52 @@ func runList(cfg *config.Config, stdout, stderr io.Writer) int { } return exitOK } + +func startService(service string, envName string, envPath string, cfg config.Config, supervise bool) error { + srvConfig, ok := cfg.Services[service] + if !ok { + return fmt.Errorf("service %s not found in config", service) + } + // Append the env for the runtime (eg. 
JAVA_HOME) + if srvConfig.EnvVars == nil { + srvConfig.EnvVars = make(map[string]string) + } + srvConfig.EnvVars[envName] = envPath + if !supervise { + return exec.RunExecutable(srvConfig.Executable, srvConfig.ExecutableArgs, srvConfig.EnvVars) + } + process, err := exec.RunExecutableAsync(srvConfig.Executable, srvConfig.ExecutableArgs, srvConfig.EnvVars) + if err != nil { + return err + } + // Run the health checks in order; on the first failure interrupt the child and report the check error (the signal result is kept in its own variable so it cannot overwrite that error) + for _, checkCfg := range srvConfig.HealthChecks { + switch checkCfg.Type { + case exec.PortHealthCheckType: + portCheck := exec.PortHealthCheck{ + PID: process.Process.Pid, + Config: checkCfg, + } + if err = portCheck.Check(); err != nil { + if sigErr := process.Process.Signal(os.Interrupt); sigErr != nil { + return fmt.Errorf("failed to send interrupt signal to process: %w", sigErr) + } + return fmt.Errorf("health check failed: %w", err) + } + default: + if sigErr := process.Process.Signal(os.Interrupt); sigErr != nil { + return fmt.Errorf("failed to send interrupt signal to process: %w", sigErr) + } + return fmt.Errorf("unknown health check type: %s", checkCfg.Type) + } + } + // Notify systemd daemon that service has started + if _, err = daemon.SdNotify(false, daemon.SdNotifyReady); err != nil { + fmt.Fprintf(os.Stderr, "systemd notification failed: %v\n", err) + } + // TODO: Replace this with an actual supervisor loop + if err = process.Wait(); err != nil { + return err + } + return nil +} diff --git a/configs/config.yaml b/configs/config.yaml index 0641ea1..0121bd6 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -91,4 +91,4 @@ services: FLINK2: path: /etc/flink2/conf/flink2-java.yaml OZONE: - path: /etc/ozone/conf/ozone-java.yaml \ No newline at end of file + path: /etc/ozone/conf/ozone-java.yaml diff --git a/examples/config/kafka.yaml b/examples/config/kafka.yaml new file mode 100644 index 0000000..65cb6d5 --- /dev/null +++ b/examples/config/kafka.yaml @@ -0,0 +1,17 @@ +executable: bin/kafka-server-start.sh +executable_args: + - config/server.properties 
+runtimes: + java: + version: "21" +health_checks: + - type: port + params: + port: 9092 + timeout: 20 + protocol: tcp + - type: port + params: + port: 9093 + timeout: 20 + protocol: tcp diff --git a/examples/systemd/kafka.service b/examples/systemd/kafka.service new file mode 100644 index 0000000..c420f30 --- /dev/null +++ b/examples/systemd/kafka.service @@ -0,0 +1,10 @@ +[Unit] +Description=Test Sleep + +[Service] +Type=notify +ExecStart=bin/ad-runtime-utils --config configs/config.yaml --service kafka --runtime java --start --supervise +TimeoutStartSec=120 + +[Install] +WantedBy=multi-user.target diff --git a/go.mod b/go.mod index 306a510..7ca13e7 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/arenadata/ad-runtime-utils go 1.24.4 -require github.com/goccy/go-yaml v1.18.0 +require ( + github.com/coreos/go-systemd/v22 v22.6.0 + github.com/goccy/go-yaml v1.18.0 +) diff --git a/go.sum b/go.sum index eb0d822..6de895e 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= diff --git a/internal/config/config.go b/internal/config/config.go index 172e661..6821a83 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "strings" "github.com/goccy/go-yaml" ) @@ -14,9 +15,19 @@ type RuntimeSetting struct { Paths []string `yaml:"paths,omitempty"` } +type HealthCheckConfig struct { + Type string `yaml:"type"` + Params map[string]any `yaml:"params,omitempty"` +} + type ServiceConfig struct { - Runtimes map[string]RuntimeSetting `yaml:"runtimes,omitempty"` - Path string `yaml:"path,omitempty"` + Runtimes map[string]RuntimeSetting 
`yaml:"runtimes,omitempty"` + Path string `yaml:"path,omitempty"` + Executable string `yaml:"executable,omitempty"` + ExecutableArgs []string `yaml:"executable_args,omitempty"` + EnvVars map[string]string `yaml:"env_vars,omitempty"` + EnvVarsFile string `yaml:"env_vars_file,omitempty"` + HealthChecks []HealthCheckConfig `yaml:"health_checks,omitempty"` } type Config struct { @@ -31,6 +42,30 @@ type Config struct { Services map[string]ServiceConfig `yaml:"services"` } +func (h *HealthCheckConfig) ParamToInt(name string) (int, error) { + val, found := h.Params[name] + if !found { + return 0, fmt.Errorf("health check param %q not found", name) + } + intVal, ok := val.(int) + if !ok { + return 0, fmt.Errorf("health check param %q is not an integer", name) + } + return intVal, nil +} + +func (h *HealthCheckConfig) ParamToString(name string) (string, error) { + val, found := h.Params[name] + if !found { + return "", fmt.Errorf("health check param %q not found", name) + } + strVal, ok := val.(string) + if !ok { + return "", fmt.Errorf("health check param %q is not a string", name) + } + return strVal, nil +} + func Load(path string) (*Config, error) { data, readErr := os.ReadFile(path) if readErr != nil { @@ -44,37 +79,70 @@ func Load(path string) (*Config, error) { } for name, svc := range cfg.Services { - if svc.Path == "" { - continue - } - - extData, readExtErr := os.ReadFile(svc.Path) - if readExtErr != nil { - if os.IsNotExist(readExtErr) { - continue + // Load Service config from file if specified + if svc.Path != "" { + if extCfg, fullCfgErr := parseExternalConfig(svc.Path); fullCfgErr == nil { + if extSvc, found := extCfg.Services[name]; found { + cfg.Services[name] = extSvc + continue + } } - return nil, fmt.Errorf("read service config %q: %w", svc.Path, readExtErr) - } - var extCfg Config - fullCfgErr := yaml.UnmarshalWithOptions(extData, &extCfg, yaml.Strict()) - if fullCfgErr == nil { - if extSvc, found := extCfg.Services[name]; found && len(extSvc.Runtimes) 
> 0 { - svc.Runtimes = extSvc.Runtimes - cfg.Services[name] = svc - continue + } + } + // Load Service from external file if specified + ext, err := parseExternalServiceConfig(svc.Path) + if err != nil { + return nil, fmt.Errorf("parse external service config %q: %w", svc.Path, err) } + // Keep the path to the Service config file for future references (must be copied before svc is replaced) + ext.Path = svc.Path + // Replace ServiceConfig with the loaded one + svc = ext } - var ext ServiceConfig - fallbackErr := yaml.UnmarshalWithOptions(extData, &ext, yaml.Strict()) - if fallbackErr != nil { - return nil, fmt.Errorf("parse service config %q: %w", svc.Path, fallbackErr) - } - - svc.Runtimes = ext.Runtimes cfg.Services[name] = svc } return &cfg, nil } + +func parseExternalConfig(path string) (Config, error) { + var cfg Config + + extData, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return cfg, nil + } + return cfg, fmt.Errorf("read service config %q: %w", path, err) + } + + if err = yaml.UnmarshalWithOptions(extData, &cfg, yaml.Strict()); err == nil { + return cfg, nil + } + return cfg, nil +} + +func parseExternalServiceConfig(path string) (ServiceConfig, error) { + var cfg ServiceConfig + extData, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return cfg, nil + } + return cfg, fmt.Errorf("read service config %q: %w", path, err) + } + if err = yaml.UnmarshalWithOptions(extData, &cfg, yaml.Strict()); err != nil { + return cfg, fmt.Errorf("parse service config %q: %w", path, err) + } + // Workaround to support sourcing env files + // It changes exec to bash and adds source command and the original executable as the args to bash + if cfg.EnvVarsFile != "" { + argsString := strings.Join(cfg.ExecutableArgs, " ") + cfg.ExecutableArgs = []string{ + "-c", + fmt.Sprintf("source %s; %s %s", cfg.EnvVarsFile, cfg.Executable, argsString), + } + cfg.Executable = "bash" + } + return cfg, nil +} diff --git a/internal/exec/check.go b/internal/exec/check.go new file mode 100644 
index 0000000..ce74ed9 --- /dev/null +++ b/internal/exec/check.go @@ -0,0 +1,115 @@ +package exec + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/arenadata/ad-runtime-utils/internal/config" +) + +const ( + PortHealthCheckType = "port" + PortHealthCheckTimeoutParamName = "timeout" + PortHealthCheckTimeoutDefault = 60 + PortHealthCheckPortParamName = "port" + PortHealthCheckProtocolParamName = "protocol" + PortHealthCheckProtocolDefault = TCP +) + +type HealthCheck interface { + Check() error +} + +// PortHealthCheck checks that a given port is open by the process with a given PID. +type PortHealthCheck struct { + Port int + Timeout int + PID int + Config config.HealthCheckConfig + Protocol SocketProtocol +} + +func (h *PortHealthCheck) Check() error { + var infos []SocketInfo + var err error + + if err = h.parseConfig(); err != nil { + return err + } + + endTime := time.Now().Add(time.Duration(h.Timeout) * time.Second) + for time.Now().Before(endTime) { + switch h.Protocol { + case TCP, TCP6: + if infos, err = GetTCPSocketsForPid(h.PID); err != nil { + fmt.Fprintf(os.Stderr, "Error getting TCP sockets for PID, retrying: %v\n", err) + } + case UDP, UDP6: + if infos, err = GetUDPSocketsForPid(h.PID); err != nil { + fmt.Fprintf(os.Stderr, "Error getting UDP sockets for PID, retrying: %v\n", err) + } + default: + return fmt.Errorf("invalid socket protocol: %s", h.Protocol) + } + for _, info := range infos { + if info.Port == h.Port { + return nil + } + } + time.Sleep(1 * time.Second) + } + return fmt.Errorf("port %d not open after %d seconds", h.Port, h.Timeout) +} + +func (h *PortHealthCheck) parseConfig() error { + // Port (required). NOTE(review): the value is asserted to be a string, but examples/config/kafka.yaml gives `port: 9092`, an unquoted YAML integer - confirm what the decoder yields, otherwise this returns "has invalid value" + _, found := h.Config.Params[PortHealthCheckPortParamName] + if !found { + return fmt.Errorf("missing %s parameter", PortHealthCheckPortParamName) + } + portStr, ok := h.Config.Params[PortHealthCheckPortParamName].(string) + if !ok { + return fmt.Errorf("parameter %s has invalid value", PortHealthCheckPortParamName) + } + var port 
int + var err error + port, err = strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("parameter %s has invalid value: %w", PortHealthCheckPortParamName, err) + } + h.Port = port + + // Protocol (optional, defaults to PortHealthCheckProtocolDefault); the string is converted to SocketProtocol without validation here + _, found = h.Config.Params[PortHealthCheckProtocolParamName] + if !found { + h.Protocol = PortHealthCheckProtocolDefault + } else { + var protocol string + protocol, ok = h.Config.Params[PortHealthCheckProtocolParamName].(string) + if !ok { + return fmt.Errorf("parameter %s has invalid value", PortHealthCheckProtocolParamName) + } + h.Protocol = SocketProtocol(protocol) + } + + // Timeout (optional, seconds). A missing OR non-string value silently falls back to the default; the inner !ok branch repeats the same assertion and is unreachable. NOTE(review): `timeout: 20` in the example YAML is an unquoted integer - confirm it is not ignored + _, ok = h.Config.Params[PortHealthCheckTimeoutParamName].(string) + if !ok { + h.Timeout = PortHealthCheckTimeoutDefault + } else { + var timeoutStr string + timeoutStr, ok = h.Config.Params[PortHealthCheckTimeoutParamName].(string) + if !ok { + return fmt.Errorf("parameter %s has invalid value", PortHealthCheckTimeoutParamName) + } + var timeout int + timeout, err = strconv.Atoi(timeoutStr) + if err != nil { + return fmt.Errorf("parameter %s has invalid value: %w", PortHealthCheckTimeoutParamName, err) + } + h.Timeout = timeout + } + return nil +} diff --git a/internal/exec/exec.go b/internal/exec/exec.go new file mode 100644 index 0000000..ed5e740 --- /dev/null +++ b/internal/exec/exec.go @@ -0,0 +1,32 @@ +package exec + +import ( + "context" + "fmt" + "os/exec" +) + +// RunExecutableAsync starts the given service with the provided arguments in a non-blocking way. +func RunExecutableAsync(executablePath string, args []string, envVars map[string]string) (*exec.Cmd, error) { + ctx := context.TODO() + cmd := exec.CommandContext(ctx, executablePath, args...) + + // Add environment variables to the command. NOTE(review): cmd.Env starts out nil, so once any variable is appended the child gets only these variables and no longer inherits the parent environment - confirm this is intended + for k, v := range envVars { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + if err := cmd.Start(); err != nil { + return nil, err + } + return cmd, nil +} + +// RunExecutable starts the given service with the provided arguments in a blocking way. 
+func RunExecutable(executablePath string, args []string, envVars map[string]string) error { + cmd, err := RunExecutableAsync(executablePath, args, envVars) + if err != nil { + return err + } + return cmd.Wait() +} diff --git a/internal/exec/socket.go b/internal/exec/socket.go new file mode 100644 index 0000000..20b1e8a --- /dev/null +++ b/internal/exec/socket.go @@ -0,0 +1,219 @@ +package exec + +import ( + "bufio" + "encoding/hex" + "fmt" + "net" + "os" + "strconv" + "strings" +) + +type SocketProtocol string + +const ( + TCP SocketProtocol = "tcp" + UDP SocketProtocol = "udp" + TCP6 SocketProtocol = "tcp6" + UDP6 SocketProtocol = "udp6" +) + +// SocketInfo represents information about a socket. +// Protocol represents the protocol of the socket (TCP or UDP). +// IP represents the IP address of the socket. +// Port represents the port number of the socket. +type SocketInfo struct { + Protocol SocketProtocol + IP net.IP + Port int +} + +type NetworkSocketStat struct { + SL string + LocalAddress string + RemoteAddress string + State string + Queue string + Timer string + Retransmits string + UID string + Timeout string + Inode string +} + +func getInodeForPid(pid int) (map[string]int, error) { + // Get file descriptors for the pid. 
+ fdLink := fmt.Sprintf("/proc/%d/fd", pid) + fdDir, err := os.Open(fdLink) + if err != nil { + return nil, err + } + fdEntries, err := fdDir.Readdirnames(-1) + if err != nil { + return nil, err + } + if cerr := fdDir.Close(); cerr != nil { + return nil, cerr + } + inodes := make(map[string]int) + for _, fdEntry := range fdEntries { + var fdPath string + fdPath, err = os.Readlink(fmt.Sprintf("/proc/%d/fd/%s", pid, fdEntry)) + // File descriptor might be closed already, so skip it + if err != nil { + continue + } + // Check if file descriptor is a socket + if strings.HasPrefix(fdPath, "socket:[") { + inode := strings.TrimPrefix(strings.TrimSuffix(fdPath, "]"), "socket:[") + inodes[inode] = pid + } + } + return inodes, nil +} + +func parseNetworkStat(statFilePath string) ([]NetworkSocketStat, error) { + tcpStatFile, err := os.Open(statFilePath) + if err != nil { + return nil, err + } + defer tcpStatFile.Close() + + var stats []NetworkSocketStat + scanner := bufio.NewScanner(tcpStatFile) + // Skip the header and return if there are no entries + if !scanner.Scan() { + return stats, nil + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + fields := strings.Fields(line) + // Skip invalid fields + tcpStatFieldsCount := 10 + if len(fields) < tcpStatFieldsCount { + continue + } + stats = append(stats, NetworkSocketStat{ + SL: fields[0], + LocalAddress: fields[1], + RemoteAddress: fields[2], + State: fields[3], + Queue: fields[4], + Timer: fields[5], + Retransmits: fields[6], + UID: fields[7], + Timeout: fields[8], + Inode: fields[9], + }) + } + return stats, scanner.Err() +} + +func getSocketsForInodes(inodes map[string]int, protocol SocketProtocol) ([]SocketInfo, error) { + var statFilePath string + switch protocol { + case TCP: + statFilePath = "/proc/net/tcp" + case UDP: + statFilePath = "/proc/net/udp" + case TCP6: + statFilePath = "/proc/net/tcp6" + case UDP6: + statFilePath = "/proc/net/udp6" + default: + return 
nil, fmt.Errorf("invalid socket protocol: %q", protocol) + } + + netStats, err := parseNetworkStat(statFilePath) + if err != nil { + return nil, err + } + + var sockets []SocketInfo + for _, stat := range netStats { + // Hex IP: addressParts[0], Hex Port AddressPart[1] + addressFieldsCount := 2 + addressParts := strings.Split(stat.LocalAddress, ":") + if len(addressParts) != addressFieldsCount { + continue + } + + var ipBytes []byte + ipBytes, err = hex.DecodeString(addressParts[0]) + if err != nil { + continue + } + // Keep only 4-byte (IPv4) or 16-byte (IPv6) addresses. NOTE(review): the decoded bytes are used as-is; confirm the byte order of the /proc/net address field matches what net.IP expects + if len(ipBytes) != 4 && len(ipBytes) != 16 { + continue + } + ip := net.IP(ipBytes) + // Convert hex port to int + var port int64 + port, err = strconv.ParseInt(addressParts[1], 16, 32) + if err != nil { + continue + } + // Check if the socket is in our inodes list, inode is fields[9] + if _, exists := inodes[stat.Inode]; exists { + sockets = append(sockets, SocketInfo{ + Protocol: protocol, + IP: ip, + Port: int(port), + }) + } + } + return sockets, nil +} + +// GetTCPSocketsForPid returns the TCP sockets (IPv4 and IPv6) owned by the given process ID. The socket state is not checked, so the result is not limited to listening sockets. +func GetTCPSocketsForPid(pid int) ([]SocketInfo, error) { + inodes, err := getInodeForPid(pid) + if err != nil { + return nil, err + } + sockets, err := getSocketsForInodes(inodes, TCP) + if err != nil { + return nil, err + } + sockets6, err := getSocketsForInodes(inodes, TCP6) + if err != nil { + return nil, err + } + return append(sockets, sockets6...), nil +} + +// GetUDPSocketsForPid returns the UDP sockets (IPv4 and IPv6) owned by the given process ID. The socket state is not checked. 
+func GetUDPSocketsForPid(pid int) ([]SocketInfo, error) { + inodes, err := getInodeForPid(pid) + if err != nil { + return nil, err + } + sockets, err := getSocketsForInodes(inodes, UDP) + if err != nil { + return nil, err + } + sockets6, err := getSocketsForInodes(inodes, UDP6) + if err != nil { + return nil, err + } + return append(sockets, sockets6...), nil +} + +// GetSocketsForPid returns all TCP and UDP sockets (IPv4 and IPv6) owned by the given process ID; the TCP results come first and the socket state is not checked. +func GetSocketsForPid(pid int) ([]SocketInfo, error) { + tcpSockets, err := GetTCPSocketsForPid(pid) + if err != nil { + return nil, err + } + udpSockets, err := GetUDPSocketsForPid(pid) + if err != nil { + return nil, err + } + return append(tcpSockets, udpSockets...), nil +} diff --git a/internal/exec/socket_test.go b/internal/exec/socket_test.go new file mode 100644 index 0000000..f362737 --- /dev/null +++ b/internal/exec/socket_test.go @@ -0,0 +1,43 @@ +package exec + +import ( + "net" + "os" + "testing" +) + +func startSimpleServer(t *testing.T, protocol string) (int, error) { + // Start a simple TCP server on an ephemeral port and return that port. NOTE(review): the listener is never closed, so the accept goroutine can still be blocked when the test returns + listener, err := net.Listen(protocol, ":0") + if err != nil { + return 0, err + } + go func() { + conn, lerr := listener.Accept() + if lerr != nil { + t.Error(lerr) + return + } + conn.Close() + }() + return listener.Addr().(*net.TCPAddr).Port, nil +} + +func TestGetTCPSocketsForPid(t *testing.T) { + port, err := startSimpleServer(t, string(TCP)) + if err != nil { + t.Error(err) + } + pid := os.Getpid() + sockets, err := GetTCPSocketsForPid(pid) + if err != nil { + t.Error(err) + return + } + for _, soc := range sockets { + if soc.Port == port { + return + } + } + t.Error("Expected to find a socket listening on the port of the started server") +}