diff --git a/net/core/netdev-genl.c b/net/core/netdev-genl.c index 739598d34657..470fabbeacd9 100644 --- a/net/core/netdev-genl.c +++ b/net/core/netdev-genl.c @@ -908,6 +908,30 @@ static int netdev_nl_read_rxq_bitmap(struct genl_info *info, return 0; } +static struct device * +netdev_nl_get_dma_dev(struct net_device *netdev, unsigned long *rxq_bitmap, + struct netlink_ext_ack *extack) +{ + struct device *dma_dev = NULL; + u32 rxq_idx, prev_rxq_idx; + + for_each_set_bit(rxq_idx, rxq_bitmap, netdev->real_num_rx_queues) { + struct device *rxq_dma_dev; + + rxq_dma_dev = netdev_queue_get_dma_dev(netdev, rxq_idx); + if (dma_dev && rxq_dma_dev != dma_dev) { + NL_SET_ERR_MSG_FMT(extack, "DMA device mismatch between queue %u and %u (multi-PF device?)", + rxq_idx, prev_rxq_idx); + return ERR_PTR(-EOPNOTSUPP); + } + + dma_dev = rxq_dma_dev; + prev_rxq_idx = rxq_idx; + } + + return dma_dev; +} + int netdev_nl_bind_rx_doit(struct sk_buff *skb, struct genl_info *info) { struct net_devmem_dmabuf_binding *binding; @@ -971,7 +995,12 @@ int netdev_nl_bind_rx_doit(struct sk_buff *skb, struct genl_info *info) if (err) goto err_rxq_bitmap; - dma_dev = netdev_queue_get_dma_dev(netdev, 0); + dma_dev = netdev_nl_get_dma_dev(netdev, rxq_bitmap, info->extack); + if (IS_ERR(dma_dev)) { + err = PTR_ERR(dma_dev); + goto err_rxq_bitmap; + } + binding = net_devmem_bind_dmabuf(netdev, dma_dev, DMA_FROM_DEVICE, dmabuf_fd, priv, info->extack); if (IS_ERR(binding)) {