diff --git a/drivers/block/nbd.c b/drivers/block/nbd.c index fe63f3c55d0d..0e2180e910c4 100644 --- a/drivers/block/nbd.c +++ b/drivers/block/nbd.c @@ -1238,6 +1238,42 @@ static struct socket *nbd_get_socket(struct nbd_device *nbd, unsigned long fd, return sock; } +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key nbd_key[3]; +static struct lock_class_key nbd_slock_key[3]; + +static void nbd_reclassify_socket(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (WARN_ON_ONCE(!sock_allow_reclassification(sk))) + return; + + switch (sk->sk_family) { + case AF_INET: + sock_lock_init_class_and_name(sk, "slock-AF_INET-NBD", + &nbd_slock_key[0], + "sk_lock-AF_INET-NBD", + &nbd_key[0]); + break; + case AF_INET6: + sock_lock_init_class_and_name(sk, "slock-AF_INET6-NBD", + &nbd_slock_key[1], + "sk_lock-AF_INET6-NBD", + &nbd_key[1]); + break; + case AF_UNIX: + sock_lock_init_class_and_name(sk, "slock-AF_UNIX-NBD", + &nbd_slock_key[2], + "sk_lock-AF_UNIX-NBD", + &nbd_key[2]); + break; + } +} +#else +static inline void nbd_reclassify_socket(struct socket *sock) {} +#endif + static int nbd_add_socket(struct nbd_device *nbd, unsigned long arg, bool netlink) { @@ -1254,6 +1290,7 @@ static int nbd_add_socket(struct nbd_device *nbd, unsigned long arg, sock = nbd_get_socket(nbd, arg, &err); if (!sock) return err; + nbd_reclassify_socket(sock); /* * We need to make sure we don't get any errant requests while we're