From 0f4fc8fbab01f95503fda9d76197fc75357f41db Mon Sep 17 00:00:00 2001
From: "p.zindyaev"
Date: Thu, 6 Nov 2025 15:40:15 +0300
Subject: [PATCH] feature(systemd): added the ability to start a service with
the predefined config
---
.gitignore | 1 +
README.MD | 13 +-
cmd/ad-runtime-utils/app.go | 63 ++++++++++
configs/config.yaml | 2 +-
examples/config/kafka.yaml | 17 +++
examples/systemd/kafka.service | 10 ++
go.mod | 5 +-
go.sum | 2 +
internal/config/config.go | 120 ++++++++++++++----
internal/exec/check.go | 115 +++++++++++++++++
internal/exec/exec.go | 32 +++++
internal/exec/socket.go | 219 +++++++++++++++++++++++++++++++++
internal/exec/socket_test.go | 43 +++++++
13 files changed, 613 insertions(+), 29 deletions(-)
create mode 100644 examples/config/kafka.yaml
create mode 100644 examples/systemd/kafka.service
create mode 100644 internal/exec/check.go
create mode 100644 internal/exec/exec.go
create mode 100644 internal/exec/socket.go
create mode 100644 internal/exec/socket_test.go
diff --git a/.gitignore b/.gitignore
index 69dcb52..f6ccf46 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
# Added by goreleaser init:
dist/
+cmd/ad-runtime-utils/ad-runtime-utils
diff --git a/README.MD b/README.MD
index aa82650..2bd1629 100644
--- a/README.MD
+++ b/README.MD
@@ -91,4 +91,15 @@ Default runtimes:
Service :
: ">
…
-```
\ No newline at end of file
+```
+
+### 4. Starting a Service
+
+When --start and/or --supervise flags are provided --service is a required argument.
+
+- It starts the executable `services..executable` with the specified arguments from `services..executable_args`.
+
+- If the `--supervise` flag is provided, it will run the health checks, defined in `services..health_checks` on service start to make sure the service is operational. If any of the checks fail the service will be stopped. `Type=notify` should be used in the systemd unit(see example in `examples/systemd` directory).
+
+- While using `--supervise` and health checks, make sure that systemd service has enough `TimeoutStartSec`. ideally should be a combined timeout of all health checks.
+
diff --git a/cmd/ad-runtime-utils/app.go b/cmd/ad-runtime-utils/app.go
index d6ba747..a512ef3 100644
--- a/cmd/ad-runtime-utils/app.go
+++ b/cmd/ad-runtime-utils/app.go
@@ -4,10 +4,13 @@ import (
"flag"
"fmt"
"io"
+ "os"
"strings"
"github.com/arenadata/ad-runtime-utils/internal/config"
"github.com/arenadata/ad-runtime-utils/internal/detect"
+ "github.com/arenadata/ad-runtime-utils/internal/exec"
+ "github.com/coreos/go-systemd/v22/daemon"
)
// exit codes.
@@ -27,6 +30,8 @@ func Run(args []string, stdout, stderr io.Writer) int {
listAll := fs.Bool("list", false, "List all detected runtimes (default + services)")
fs.BoolVar(listAll, "l", false, "shorthand for --list")
printCACerts := fs.Bool("print-cacerts", false, "When used with --runtime=java, prints the cacerts path and exits")
+ start := fs.Bool("start", false, "Start the service. Use with simple/exec services")
+ supervise := fs.Bool("supervise", false, "Supervise the service. Use with notify systemd services")
if err := fs.Parse(args); err != nil {
return exitParseError
@@ -76,6 +81,15 @@ func Run(args []string, stdout, stderr io.Writer) int {
}
envName := detectEnvName(cfg, *service, *runtime)
+
+ if *start {
+ if err = startService(*service, envName, path, *cfg, *supervise); err != nil {
+ fmt.Fprintf(stderr, "start service failed: %v\n", err)
+ return exitUserError
+ }
+ return exitOK
+ }
+
fmt.Fprintf(stdout, "export %s=%s\n", envName, path)
return exitOK
}
@@ -124,3 +138,52 @@ func runList(cfg *config.Config, stdout, stderr io.Writer) int {
}
return exitOK
}
+
+func startService(service string, envName string, envPath string, cfg config.Config, supervise bool) error {
+ srvConfig, ok := cfg.Services[service]
+ if !ok {
+ return fmt.Errorf("service %s not found in config", service)
+ }
+ // Append the env for the runtime (eg. JAVA_HOME)
+ if srvConfig.EnvVars == nil {
+ srvConfig.EnvVars = make(map[string]string)
+ }
+ srvConfig.EnvVars[envName] = envPath
+ if !supervise {
+ return exec.RunExecutable(srvConfig.Executable, srvConfig.ExecutableArgs, srvConfig.EnvVars)
+ }
+ process, err := exec.RunExecutableAsync(srvConfig.Executable, srvConfig.ExecutableArgs, srvConfig.EnvVars)
+ if err != nil {
+ return err
+ }
+ // Run the health checks
+ for _, checkCfg := range srvConfig.HealthChecks {
+ switch checkCfg.Type {
+ case exec.PortHealthCheckType:
+ portheck := exec.PortHealthCheck{
+ PID: process.Process.Pid,
+ Config: checkCfg,
+ }
+ if err = portheck.Check(); err != nil {
+ if err = process.Process.Signal(os.Interrupt); err != nil {
+ return fmt.Errorf("failed to send interrupt signal to process: %w", err)
+ }
+ return fmt.Errorf("health check failed: %w", err)
+ }
+ default:
+ if err = process.Process.Signal(os.Interrupt); err != nil {
+ return fmt.Errorf("failed to send interrupt signal to process: %w", err)
+ }
+ return fmt.Errorf("unknown health check type: %s", checkCfg.Type)
+ }
+ }
+ // Notify systemd daemon that service has started
+ if _, err = daemon.SdNotify(false, daemon.SdNotifyReady); err != nil {
+ fmt.Fprintf(os.Stderr, "systemd notification failed: %v\n", err)
+ }
+ // TODO: Replace this with an actual supervisor loop
+ if err = process.Wait(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/configs/config.yaml b/configs/config.yaml
index 0641ea1..0121bd6 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -91,4 +91,4 @@ services:
FLINK2:
path: /etc/flink2/conf/flink2-java.yaml
OZONE:
- path: /etc/ozone/conf/ozone-java.yaml
\ No newline at end of file
+ path: /etc/ozone/conf/ozone-java.yaml
diff --git a/examples/config/kafka.yaml b/examples/config/kafka.yaml
new file mode 100644
index 0000000..65cb6d5
--- /dev/null
+++ b/examples/config/kafka.yaml
@@ -0,0 +1,17 @@
+executable: bin/kafka-server-start.sh
+executable_args:
+ - config/server.properties
+runtimes:
+ java:
+ version: "21"
+health_checks:
+ - type: port
+ params:
+ port: 9092
+ timeout: 20
+ protocol: tcp
+ - type: port
+ params:
+ port: 9093
+ timeout: 20
+ protocol: tcp
diff --git a/examples/systemd/kafka.service b/examples/systemd/kafka.service
new file mode 100644
index 0000000..c420f30
--- /dev/null
+++ b/examples/systemd/kafka.service
@@ -0,0 +1,10 @@
+[Unit]
+Description=Test Sleep
+
+[Service]
+Type=notify
+ExecStart=bin/ad-runtime-utils --config configs/config.yaml --service kafka --runtime java --start --supervise
+TimeoutStartSec=120
+
+[Install]
+WantedBy=multi-user.target
diff --git a/go.mod b/go.mod
index 306a510..7ca13e7 100644
--- a/go.mod
+++ b/go.mod
@@ -2,4 +2,7 @@ module github.com/arenadata/ad-runtime-utils
go 1.24.4
-require github.com/goccy/go-yaml v1.18.0
+require (
+ github.com/coreos/go-systemd/v22 v22.6.0
+ github.com/goccy/go-yaml v1.18.0
+)
diff --git a/go.sum b/go.sum
index eb0d822..6de895e 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +1,4 @@
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
diff --git a/internal/config/config.go b/internal/config/config.go
index 172e661..6821a83 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -3,6 +3,7 @@ package config
import (
"fmt"
"os"
+ "strings"
"github.com/goccy/go-yaml"
)
@@ -14,9 +15,19 @@ type RuntimeSetting struct {
Paths []string `yaml:"paths,omitempty"`
}
+type HealthCheckConfig struct {
+ Type string `yaml:"type"`
+ Params map[string]any `yaml:"params,omitempty"`
+}
+
type ServiceConfig struct {
- Runtimes map[string]RuntimeSetting `yaml:"runtimes,omitempty"`
- Path string `yaml:"path,omitempty"`
+ Runtimes map[string]RuntimeSetting `yaml:"runtimes,omitempty"`
+ Path string `yaml:"path,omitempty"`
+ Executable string `yaml:"executable,omitempty"`
+ ExecutableArgs []string `yaml:"executable_args,omitempty"`
+ EnvVars map[string]string `yaml:"env_vars,omitempty"`
+ EnvVarsFile string `yaml:"env_vars_file,omitempty"`
+ HealthChecks []HealthCheckConfig `yaml:"health_checks,omitempty"`
}
type Config struct {
@@ -31,6 +42,30 @@ type Config struct {
Services map[string]ServiceConfig `yaml:"services"`
}
+func (h *HealthCheckConfig) ParamToInt(name string) (int, error) {
+ val, found := h.Params[name]
+ if !found {
+ return 0, fmt.Errorf("health check param %q not found", name)
+ }
+ intVal, ok := val.(int)
+ if !ok {
+ return 0, fmt.Errorf("health check param %q is not an integer", name)
+ }
+ return intVal, nil
+}
+
+func (h *HealthCheckConfig) ParamToString(name string) (string, error) {
+ val, found := h.Params[name]
+ if !found {
+ return "", fmt.Errorf("health check param %q not found", name)
+ }
+ strVal, ok := val.(string)
+ if !ok {
+ return "", fmt.Errorf("health check param %q is not a string", name)
+ }
+ return strVal, nil
+}
+
func Load(path string) (*Config, error) {
data, readErr := os.ReadFile(path)
if readErr != nil {
@@ -44,37 +79,70 @@ func Load(path string) (*Config, error) {
}
for name, svc := range cfg.Services {
- if svc.Path == "" {
- continue
- }
-
- extData, readExtErr := os.ReadFile(svc.Path)
- if readExtErr != nil {
- if os.IsNotExist(readExtErr) {
- continue
+ // Load Service config from file if specified
+ if svc.Path != "" {
+ if extCfg, fullCfgErr := parseExternalConfig(svc.Path); fullCfgErr == nil {
+ if extSvc, found := extCfg.Services[name]; found {
+ cfg.Services[name] = extSvc
+ continue
+ }
}
- return nil, fmt.Errorf("read service config %q: %w", svc.Path, readExtErr)
- }
- var extCfg Config
- fullCfgErr := yaml.UnmarshalWithOptions(extData, &extCfg, yaml.Strict())
- if fullCfgErr == nil {
- if extSvc, found := extCfg.Services[name]; found && len(extSvc.Runtimes) > 0 {
- svc.Runtimes = extSvc.Runtimes
- cfg.Services[name] = svc
- continue
+ // Load Service from external file if specified
+ ext, err := parseExternalServiceConfig(svc.Path)
+ if err != nil {
+ return nil, fmt.Errorf("parse external service config %q: %w", svc.Path, err)
}
+ // Replace ServiceConfig with the loaded one
+ svc = ext
+ // Keep the path to the Service config file for future references
+ svc.Path = path
}
- var ext ServiceConfig
- fallbackErr := yaml.UnmarshalWithOptions(extData, &ext, yaml.Strict())
- if fallbackErr != nil {
- return nil, fmt.Errorf("parse service config %q: %w", svc.Path, fallbackErr)
- }
-
- svc.Runtimes = ext.Runtimes
cfg.Services[name] = svc
}
return &cfg, nil
}
+
+func parseExternalConfig(path string) (Config, error) {
+ var cfg Config
+
+ extData, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return cfg, nil
+ }
+ return cfg, fmt.Errorf("read service config %q: %w", path, err)
+ }
+
+ if err = yaml.UnmarshalWithOptions(extData, &cfg, yaml.Strict()); err == nil {
+ return cfg, nil
+ }
+ return cfg, nil
+}
+
+func parseExternalServiceConfig(path string) (ServiceConfig, error) {
+ var cfg ServiceConfig
+ extData, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return cfg, nil
+ }
+ return cfg, fmt.Errorf("read service config %q: %w", path, err)
+ }
+ if err = yaml.UnmarshalWithOptions(extData, &cfg, yaml.Strict()); err != nil {
+ return cfg, fmt.Errorf("parse service config %q: %w", path, err)
+ }
+ // Abbomination to support env files sourcing
+ // It changes exec to bash and adds source command and the original executable as the args to bash
+ if cfg.EnvVarsFile != "" {
+ argsString := strings.Join(cfg.ExecutableArgs, " ")
+ cfg.ExecutableArgs = []string{
+ "-c",
+ fmt.Sprintf("source %s; %s %s", cfg.EnvVarsFile, cfg.Executable, argsString),
+ }
+ cfg.Executable = "bash"
+ }
+ return cfg, nil
+}
diff --git a/internal/exec/check.go b/internal/exec/check.go
new file mode 100644
index 0000000..ce74ed9
--- /dev/null
+++ b/internal/exec/check.go
@@ -0,0 +1,115 @@
+package exec
+
+import (
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "github.com/arenadata/ad-runtime-utils/internal/config"
+)
+
+const (
+ PortHealthCheckType = "port"
+ PortHealthCheckTimeoutParamName = "timeout"
+ PortHealthCheckTimeoutDefault = 60
+ PortHealthCheckPortParamName = "port"
+ PortHealthCheckProtocolParamName = "protocol"
+ PortHealthCheckProtocolDefault = TCP
+)
+
+type HealthCheck interface {
+ Check() error
+}
+
+// PortHealthCheck checks that a given port is open by the process with a given PID.
+type PortHealthCheck struct {
+ Port int
+ Timeout int
+ PID int
+ Config config.HealthCheckConfig
+ Protocol SocketProtocol
+}
+
+func (h *PortHealthCheck) Check() error {
+ var infos []SocketInfo
+ var err error
+
+ if err = h.parseConfig(); err != nil {
+ return err
+ }
+
+ endTime := time.Now().Add(time.Duration(h.Timeout) * time.Second)
+ for time.Now().Before(endTime) {
+ switch h.Protocol {
+ case TCP, TCP6:
+ if infos, err = GetTCPSocketsForPid(h.PID); err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting TCP sockets for PID, retrying: %v\n", err)
+ }
+ case UDP, UDP6:
+ if infos, err = GetUDPSocketsForPid(h.PID); err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting UDP sockets for PID, retrying: %v\n", err)
+ }
+ default:
+ return fmt.Errorf("invalid socket protocol: %s", h.Protocol)
+ }
+ for _, info := range infos {
+ if info.Port == h.Port {
+ return nil
+ }
+ }
+ time.Sleep(1 * time.Second)
+ }
+ return fmt.Errorf("port %d not open after %d seconds", h.Port, h.Timeout)
+}
+
+func (h *PortHealthCheck) parseConfig() error {
+ // Port
+ _, found := h.Config.Params[PortHealthCheckPortParamName]
+ if !found {
+ return fmt.Errorf("missing %s parameter", PortHealthCheckPortParamName)
+ }
+ portStr, ok := h.Config.Params[PortHealthCheckPortParamName].(string)
+ if !ok {
+ return fmt.Errorf("parameter %s has invalid value", PortHealthCheckPortParamName)
+ }
+ var port int
+ var err error
+ port, err = strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("parameter %s has invalid value: %w", PortHealthCheckPortParamName, err)
+ }
+ h.Port = port
+
+ // Protocol
+ _, found = h.Config.Params[PortHealthCheckProtocolParamName]
+ if !found {
+ h.Protocol = PortHealthCheckProtocolDefault
+ } else {
+ var protocol string
+ protocol, ok = h.Config.Params[PortHealthCheckProtocolParamName].(string)
+ if !ok {
+ return fmt.Errorf("parameter %s has invalid value", PortHealthCheckProtocolParamName)
+ }
+ h.Protocol = SocketProtocol(protocol)
+ }
+
+ // Timeout
+ _, ok = h.Config.Params[PortHealthCheckTimeoutParamName].(string)
+ if !ok {
+ h.Timeout = PortHealthCheckTimeoutDefault
+ } else {
+ var timeoutStr string
+ timeoutStr, ok = h.Config.Params[PortHealthCheckTimeoutParamName].(string)
+ if !ok {
+ return fmt.Errorf("parameter %s has invalid value", PortHealthCheckTimeoutParamName)
+ }
+ var timeout int
+ timeout, err = strconv.Atoi(timeoutStr)
+ if err != nil {
+ return fmt.Errorf("parameter %s has invalid value: %w", PortHealthCheckTimeoutParamName, err)
+ }
+ h.Timeout = timeout
+ }
+ return nil
+}
diff --git a/internal/exec/exec.go b/internal/exec/exec.go
new file mode 100644
index 0000000..ed5e740
--- /dev/null
+++ b/internal/exec/exec.go
@@ -0,0 +1,32 @@
+package exec
+
+import (
+ "context"
+ "fmt"
+ "os/exec"
+)
+
+// RunExecutableAsync starts the given service with the provided arguments in a non-blocking way.
+func RunExecutableAsync(executablePath string, args []string, envVars map[string]string) (*exec.Cmd, error) {
+ ctx := context.TODO()
+ cmd := exec.CommandContext(ctx, executablePath, args...)
+
+ // Add environment variables to the command.
+ for k, v := range envVars {
+ cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
+ }
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+ return cmd, nil
+}
+
+// RunExecutable starts the given service with the provided arguments in a blocking way.
+func RunExecutable(executablePath string, args []string, envVars map[string]string) error {
+ cmd, err := RunExecutableAsync(executablePath, args, envVars)
+ if err != nil {
+ return err
+ }
+ return cmd.Wait()
+}
diff --git a/internal/exec/socket.go b/internal/exec/socket.go
new file mode 100644
index 0000000..20b1e8a
--- /dev/null
+++ b/internal/exec/socket.go
@@ -0,0 +1,219 @@
+package exec
+
+import (
+ "bufio"
+ "encoding/hex"
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+)
+
+type SocketProtocol string
+
+const (
+ TCP SocketProtocol = "tcp"
+ UDP SocketProtocol = "udp"
+ TCP6 SocketProtocol = "tcp6"
+ UDP6 SocketProtocol = "udp6"
+)
+
+// SocketInfo represents information about a socket.
+// Protocol represents the protocol of the socket (TCP or UDP).
+// IP represents the IP address of the socket.
+// Port represents the port number of the socket.
+type SocketInfo struct {
+ Protocol SocketProtocol
+ IP net.IP
+ Port int
+}
+
+type NetworkSocketStat struct {
+ SL string
+ LocalAddress string
+ RemoteAddress string
+ State string
+ Queue string
+ Timer string
+ Retransmits string
+ UID string
+ Timeout string
+ Inode string
+}
+
+func getInodeForPid(pid int) (map[string]int, error) {
+ // Get file descriptors for the pid.
+ fdLink := fmt.Sprintf("/proc/%d/fd", pid)
+ fdDir, err := os.Open(fdLink)
+ if err != nil {
+ return nil, err
+ }
+ fdEntries, err := fdDir.Readdirnames(-1)
+ if err != nil {
+ return nil, err
+ }
+ if cerr := fdDir.Close(); cerr != nil {
+ return nil, cerr
+ }
+ inodes := make(map[string]int)
+ for _, fdEntry := range fdEntries {
+ var fdPath string
+ fdPath, err = os.Readlink(fmt.Sprintf("/proc/%d/fd/%s", pid, fdEntry))
+ // File descriptor might be closed already, so skip it
+ if err != nil {
+ continue
+ }
+ // Check if file descriptor is a socket
+ if strings.HasPrefix(fdPath, "socket:[") {
+ inode := strings.TrimPrefix(strings.TrimSuffix(fdPath, "]"), "socket:[")
+ inodes[inode] = pid
+ }
+ }
+ return inodes, nil
+}
+
+func parseNetworkStat(statFilePath string) ([]NetworkSocketStat, error) {
+ tcpStatFile, err := os.Open(statFilePath)
+ if err != nil {
+ return nil, err
+ }
+ defer tcpStatFile.Close()
+
+ var stats []NetworkSocketStat
+ scanner := bufio.NewScanner(tcpStatFile)
+ // Skip the header and return if there are no entries
+ if !scanner.Scan() {
+ return stats, nil
+ }
+
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+ fields := strings.Fields(line)
+ // Skip invalid fields
+ tcpStatFieldsCount := 10
+ if len(fields) < tcpStatFieldsCount {
+ continue
+ }
+ stats = append(stats, NetworkSocketStat{
+ SL: fields[0],
+ LocalAddress: fields[1],
+ RemoteAddress: fields[2],
+ State: fields[3],
+ Queue: fields[4],
+ Timer: fields[5],
+ Retransmits: fields[6],
+ UID: fields[7],
+ Timeout: fields[8],
+ Inode: fields[9],
+ })
+ }
+ return stats, scanner.Err()
+}
+
+func getSocketsForInodes(inodes map[string]int, protocol SocketProtocol) ([]SocketInfo, error) {
+ var statFilePath string
+ switch protocol {
+ case TCP:
+ statFilePath = "/proc/net/tcp"
+ case UDP:
+ statFilePath = "/proc/net/udp"
+ case TCP6:
+ statFilePath = "/proc/net/tcp6"
+ case UDP6:
+ statFilePath = "/proc/net/udp6"
+ default:
+ return nil, fmt.Errorf("invalid socket protocol: %q", protocol)
+ }
+
+ netStats, err := parseNetworkStat(statFilePath)
+ if err != nil {
+ return nil, err
+ }
+
+ var sockets []SocketInfo
+ for _, stat := range netStats {
+ // Hex IP: addressParts[0], Hex Port AddressPart[1]
+ addressFieldsCount := 2
+ addressParts := strings.Split(stat.LocalAddress, ":")
+ if len(addressParts) != addressFieldsCount {
+ continue
+ }
+
+ var ipBytes []byte
+ ipBytes, err = hex.DecodeString(addressParts[0])
+ if err != nil {
+ continue
+ }
+ // Convert hex IP to the readable format
+ if len(ipBytes) != 4 && len(ipBytes) != 16 {
+ continue
+ }
+ ip := net.IP(ipBytes)
+ // Convert hex port to int
+ var port int64
+ port, err = strconv.ParseInt(addressParts[1], 16, 32)
+ if err != nil {
+ continue
+ }
+ // Check if the socket is in our inodes list, inode is fields[9]
+ if _, exists := inodes[stat.Inode]; exists {
+ sockets = append(sockets, SocketInfo{
+ Protocol: protocol,
+ IP: ip,
+ Port: int(port),
+ })
+ }
+ }
+ return sockets, nil
+}
+
+// GetTCPSocketsForPid returns a list of Listening TCP sockets for the given process ID.
+func GetTCPSocketsForPid(pid int) ([]SocketInfo, error) {
+ inodes, err := getInodeForPid(pid)
+ if err != nil {
+ return nil, err
+ }
+ sockets, err := getSocketsForInodes(inodes, TCP)
+ if err != nil {
+ return nil, err
+ }
+ sockets6, err := getSocketsForInodes(inodes, TCP6)
+ if err != nil {
+ return nil, err
+ }
+ return append(sockets, sockets6...), nil
+}
+
+// GetUDPSocketsForPid returns a list of Listening UDP sockets for the given process ID.
+func GetUDPSocketsForPid(pid int) ([]SocketInfo, error) {
+ inodes, err := getInodeForPid(pid)
+ if err != nil {
+ return nil, err
+ }
+ sockets, err := getSocketsForInodes(inodes, UDP)
+ if err != nil {
+ return nil, err
+ }
+ sockets6, err := getSocketsForInodes(inodes, UDP6)
+ if err != nil {
+ return nil, err
+ }
+ return append(sockets, sockets6...), nil
+}
+
+// GetSocketsForPid returns a list of all Listening sockets for the given process ID.
+func GetSocketsForPid(pid int) ([]SocketInfo, error) {
+ tcpSockets, err := GetTCPSocketsForPid(pid)
+ if err != nil {
+ return nil, err
+ }
+ udpSockets, err := GetUDPSocketsForPid(pid)
+ if err != nil {
+ return nil, err
+ }
+ return append(tcpSockets, udpSockets...), nil
+}
diff --git a/internal/exec/socket_test.go b/internal/exec/socket_test.go
new file mode 100644
index 0000000..f362737
--- /dev/null
+++ b/internal/exec/socket_test.go
@@ -0,0 +1,43 @@
+package exec
+
+import (
+ "net"
+ "os"
+ "testing"
+)
+
+func startSimpleServer(t *testing.T, protocol string) (int, error) {
+ // Start a simple TCP server and return the port it's listening on.
+ listener, err := net.Listen(protocol, ":0")
+ if err != nil {
+ return 0, err
+ }
+ go func() {
+ conn, lerr := listener.Accept()
+ if lerr != nil {
+ t.Error(lerr)
+ return
+ }
+ conn.Close()
+ }()
+ return listener.Addr().(*net.TCPAddr).Port, nil
+}
+
+func TestGetTCPSocketsForPid(t *testing.T) {
+ port, err := startSimpleServer(t, string(TCP))
+ if err != nil {
+ t.Error(err)
+ }
+ pid := os.Getpid()
+ sockets, err := GetTCPSocketsForPid(pid)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ for _, soc := range sockets {
+ if soc.Port == port {
+ return
+ }
+ }
+ t.Error("Expected to find a socket listening on the port of the started server")
+}