diff --git a/driver/driver_test.go b/driver/driver_test.go index 78d6f14c..6c8e0b9e 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -160,19 +160,22 @@ func (f *fakeMounter) FindAbsoluteDeviceByIDPath(volumeName string, log *logrus. func (f *fakeMounter) IsFormatted(source string, luksContext LuksContext, log *logrus.Entry) (bool, error) { return true, nil } -func (f *fakeMounter) IsMounted(target string, log *logrus.Entry) (bool, error) { +func (f *fakeMounter) GetMountInfo(target string, log *logrus.Entry) (*MountInfo, error) { f.mu.RLock() defer f.mu.RUnlock() - _, ok := f.mounted[target] - return ok, nil + source, ok := f.mounted[target] + if !ok { + return nil, nil + } + return &MountInfo{Target: target, Source: source, Propagation: "shared"}, nil } func (f *fakeMounter) checkMountPath(path string) (sanity.PathKind, error) { - isMounted, err := f.IsMounted(path, nil) + info, err := f.GetMountInfo(path, nil) if err != nil { return "", err } - if isMounted { + if info != nil { return sanity.PathIsDir, nil } return sanity.PathIsNotFound, nil diff --git a/driver/mounter.go b/driver/mounter.go index 815cbfcd..ca5fd5a2 100644 --- a/driver/mounter.go +++ b/driver/mounter.go @@ -54,6 +54,15 @@ type fileSystem struct { Options string `json:"options"` } +// MountInfo describes a single mount as reported by findmnt. +type MountInfo struct { + Target string + Source string + Propagation string + FsType string + Options string +} + const ( // blkidExitStatusNoIdentifiers defines the exit code returned from blkid indicating that no devices have been found. See http://www.polarhome.com/service/man/?qf=blkid&tf=2&of=Alpinelinux for details. blkidExitStatusNoIdentifiers = 2 @@ -81,10 +90,11 @@ type Mounter interface { // returns true if the source device is already formatted. IsFormatted(source string, luksContext LuksContext, log *logrus.Entry) (bool, error) - // IsMounted checks whether the target path is a correct mount (i.e: - // propagated). It returns true if it's mounted. An error is returned in - // case of system errors or if it's mounted incorrectly. - IsMounted(target string, log *logrus.Entry) (bool, error) + // GetMountInfo returns the mount currently at target, or nil if nothing + // is mounted there. Returns a non-nil error only on lookup failures + // (findmnt missing, JSON parse error, etc.). Callers that require + // correct mount propagation must check info.Propagation themselves. + GetMountInfo(target string, log *logrus.Entry) (*MountInfo, error) // Used to find a path in /dev/disk/by-id with a serial that we have from // the cloudscale API. @@ -263,59 +273,33 @@ func (m *mounter) Unmount(target string, luksContext LuksContext, log *logrus.En return errors.New("target is not specified for unmounting the volume") } - // if this is the unmount call after the mount-bind has been removed, - // a luks volume needs to be closed after unmounting; get the source - // of the mount to check if that is a luks volume - mountSources, err := getMountSources(target) + // Resolve the mounted source before tearing down so we can close any + // LUKS mapping that was backing it. Mount-propagation correctness is + // not Unmount's concern — a misconfigured mount must still be cleaned up. + info, err := m.GetMountInfo(target, log) if err != nil { - return fmt.Errorf("failed to get mount sources for target %q: %v", target, err) + return fmt.Errorf("failed to get mount info for target %q: %v", target, err) } - err = mount.CleanupMountPoint(target, m.kMounter, true) - if err != nil { + if err := mount.CleanupMountPoint(target, m.kMounter, true); err != nil { return err } - // if this is the unstaging process, check if the source is a luks volume and close it - if luksContext.VolumeLifecycle == VolumeLifecycleNodeUnstageVolume { - for _, source := range mountSources { - isLuksMapping, mappingName, err := isLuksMapping(source) - if err != nil { + if luksContext.VolumeLifecycle == VolumeLifecycleNodeUnstageVolume && info != nil { + isLuksMapping, mappingName, err := isLuksMapping(info.Source) + if err != nil { + return err + } + if isLuksMapping { + if err := luksClose(mappingName, log); err != nil { return err } - if isLuksMapping { - err := luksClose(mappingName, log) - if err != nil { - return err - } - } } } return nil } -// gets the mount sources of a mountpoint -func getMountSources(target string) ([]string, error) { - _, err := exec.LookPath("findmnt") - if err != nil { - if errors.Is(err, exec.ErrNotFound) { - return nil, fmt.Errorf("%q executable not found in $PATH", "findmnt") - } - return nil, err - } - out, err := exec.Command("sh", "-c", fmt.Sprintf("findmnt -o SOURCE -n -M %s", target)).CombinedOutput() - if err != nil { - // findmnt exits with non zero exit status if it couldn't find anything - if strings.TrimSpace(string(out)) == "" { - return nil, nil - } - return nil, fmt.Errorf("checking mounted failed: %v cmd: %q output: %q", - err, "findmnt", string(out)) - } - return strings.Split(string(out), "\n"), nil -} - func (m *mounter) IsFormatted(source string, luksContext LuksContext, log *logrus.Entry) (bool, error) { if !luksContext.EncryptionEnabled { return isVolumeFormatted(source, log) @@ -369,21 +353,21 @@ func isVolumeFormatted(source string, log *logrus.Entry) (bool, error) { return true, nil } -func (m *mounter) IsMounted(target string, log *logrus.Entry) (bool, error) { +func (m *mounter) GetMountInfo(target string, log *logrus.Entry) (*MountInfo, error) { if target == "" { - return false, errors.New("target is not specified for checking the mount") + return nil, errors.New("target is not specified for checking the mount") } findmntCmd := "findmnt" _, err := exec.LookPath(findmntCmd) if err != nil { if errors.Is(err, exec.ErrNotFound) { - return false, fmt.Errorf("%q executable not found in $PATH", findmntCmd) + return nil, fmt.Errorf("%q executable not found in $PATH", findmntCmd) } - return false, err + return nil, err } - findmntArgs := []string{"-o", "TARGET,PROPAGATION,FSTYPE,OPTIONS", "-M", target, "-J"} + findmntArgs := []string{"-o", "TARGET,PROPAGATION,FSTYPE,OPTIONS,SOURCE", "-M", target, "-J"} log.WithFields(logrus.Fields{ "cmd": findmntCmd, @@ -392,40 +376,38 @@ func (m *mounter) IsMounted(target string, log *logrus.Entry) (bool, error) { out, err := exec.Command(findmntCmd, findmntArgs...).CombinedOutput() if err != nil { - // findmnt exits with non zero exit status if it couldn't find anything + // findmnt exits with non-zero exit status if it couldn't find anything if strings.TrimSpace(string(out)) == "" { - return false, nil + return nil, nil } - return false, fmt.Errorf("checking mounted failed: %v cmd: %q output: %q", + return nil, fmt.Errorf("checking mounted failed: %v cmd: %q output: %q", err, findmntCmd, string(out)) } - // no response means there is no mount - if string(out) == "" { - return false, nil + if len(out) == 0 { + return nil, nil } var resp *findmntResponse - err = json.Unmarshal(out, &resp) - if err != nil { - return false, fmt.Errorf("couldn't unmarshal data: %q: %s", string(out), err) + if err := json.Unmarshal(out, &resp); err != nil { + return nil, fmt.Errorf("couldn't unmarshal data: %q: %s", string(out), err) } - targetFound := false for _, fs := range resp.FileSystems { - // check if the mount is propagated correctly. It should be set to shared. - if fs.Propagation != "shared" { - return true, fmt.Errorf("mount propagation for target %q is not enabled", target) - } - - // the mountpoint should match as well - if fs.Target == target { - targetFound = true + if fs.Target != target { + continue } + return &MountInfo{ + Target: fs.Target, + Source: fs.Source, + Propagation: fs.Propagation, + FsType: fs.FsType, + Options: fs.Options, + }, nil } - return targetFound, nil + return nil, nil } // Copyright note for the functions below. Originally taken from diff --git a/driver/node.go b/driver/node.go index c566fe35..ec3e0a10 100644 --- a/driver/node.go +++ b/driver/node.go @@ -19,8 +19,11 @@ package driver import ( "context" + "fmt" "os" + "path/filepath" "strconv" + "strings" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/sirupsen/logrus" @@ -140,18 +143,41 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe ll.Info("checking if stagingTargetPath is already mounted") - mounted, err := d.mounter.IsMounted(stagingTargetPath, ll) + mountInfo, err := d.mounter.GetMountInfo(stagingTargetPath, ll) if err != nil { ll.WithError(err).Error("unable to check if already mounted") return nil, err } + if mountInfo != nil && mountInfo.Propagation != "shared" { + return nil, fmt.Errorf("mount propagation for target %q is not enabled", stagingTargetPath) + } - if !mounted { + if mountInfo == nil { ll.Info("not mounted yet, mounting the volume for staging") if err := d.mounter.Mount(source, stagingTargetPath, fsType, luksContext, ll, options...); err != nil { return nil, status.Error(codes.Internal, err.Error()) } } else { + // Something is already mounted at the staging path. Verify it is + // mounted from the device we just resolved before declaring success — + // otherwise a stale mount left by an earlier (failed or racing) stage + // operation can be silently accepted, which is the same class of bug + // as the LUKS mapping reuse in luksOpen. + expected := source + if luksContext.EncryptionEnabled { + expected = "/dev/mapper/" + luksContext.VolumeName + } else if resolved, err := filepath.EvalSymlinks(source); err == nil { + // findmnt reports the kernel-resolved device, so compare against + // the canonical form. Fall back to the literal source on resolve + // failure — the mismatch will then surface as a loud error. + expected = resolved + } + + if strings.TrimSpace(mountInfo.Source) != expected { + return nil, status.Errorf(codes.FailedPrecondition, + "stage path %s is mounted from %q, expected %s, refusing to reuse stale mount", + stagingTargetPath, mountInfo.Source, expected) + } ll.Info("source device is already mounted to the stagingTargetPath path") } @@ -226,12 +252,12 @@ func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolu }) ll.Info("node unstage volume called") - mounted, err := d.mounter.IsMounted(req.StagingTargetPath, ll) + mountInfo, err := d.mounter.GetMountInfo(req.StagingTargetPath, ll) if err != nil { return nil, err } - if mounted { + if mountInfo != nil { ll.Info("unmounting the staging target path") err := d.mounter.Unmount(req.StagingTargetPath, luksContext, ll) if err != nil { @@ -430,12 +456,15 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS }) ll.Info("node get volume stats called") - mounted, err := d.mounter.IsMounted(volumePath, ll) + mountInfo, err := d.mounter.GetMountInfo(volumePath, ll) if err != nil { return nil, status.Errorf(codes.Internal, "failed to check if volume path %q is mounted: %s", volumePath, err) } + if mountInfo != nil && mountInfo.Propagation != "shared" { + return nil, status.Errorf(codes.Internal, "mount propagation for target %q is not enabled", volumePath) + } - if !mounted { + if mountInfo == nil { return nil, status.Errorf(codes.NotFound, "volume path %q is not mounted", volumePath) } @@ -527,12 +556,15 @@ func (d *Driver) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolume } } - mounted, err := d.mounter.IsMounted(volumePath, ll) + mountInfo, err := d.mounter.GetMountInfo(volumePath, ll) if err != nil { return nil, status.Errorf(codes.Internal, "NodeExpandVolume failed to check if volume path %q is mounted: %s", volumePath, err) } + if mountInfo != nil && mountInfo.Propagation != "shared" { + return nil, status.Errorf(codes.Internal, "NodeExpandVolume mount propagation for target %q is not enabled", volumePath) + } - if !mounted { + if mountInfo == nil { return nil, status.Errorf(codes.NotFound, "NodeExpandVolume volume path %q is not mounted", volumePath) }