smb: server: let smbdirect_map_sges_from_iter() truncate the message boundary

smbdirect_map_sges_from_iter() already handles the case that only
a limited number of sges are available. Its return value
is data_length and the remaining bytes in the iter are
remaining_data_length.

This is now much easier and will allow us to share
more code with the client soon.

Cc: Namjae Jeon <linkinjeon@kernel.org>
Cc: Steve French <smfrench@gmail.com>
Cc: Tom Talpey <tom@talpey.com>
Cc: linux-cifs@vger.kernel.org
Cc: samba-technical@lists.samba.org
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Acked-by: Namjae Jeon <linkinjeon@kernel.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
This commit is contained in:
Stefan Metzmacher
2025-10-17 17:58:16 +02:00
committed by Steve French
parent 0af87a0a31
commit da20536c50

View File

@@ -212,7 +212,7 @@ unsigned int get_smbd_max_read_write_size(struct ksmbd_transport *kt)
static int smb_direct_post_send_data(struct smbdirect_socket *sc,
struct smbdirect_send_batch *send_ctx,
struct iov_iter *iter,
size_t *remaining_data_length);
u32 remaining_data_length);
static void smb_direct_send_immediate_work(struct work_struct *work)
{
@@ -222,7 +222,7 @@ static void smb_direct_send_immediate_work(struct work_struct *work)
if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
return;
smb_direct_post_send_data(sc, NULL, NULL, NULL);
smb_direct_post_send_data(sc, NULL, NULL, 0);
}
static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
@@ -805,23 +805,27 @@ static int post_sendmsg(struct smbdirect_socket *sc,
static int smb_direct_post_send_data(struct smbdirect_socket *sc,
struct smbdirect_send_batch *send_ctx,
struct iov_iter *iter,
size_t *_remaining_data_length)
u32 remaining_data_length)
{
const struct smbdirect_socket_parameters *sp = &sc->parameters;
int ret;
struct smbdirect_send_io *msg;
struct smbdirect_data_transfer *packet;
size_t header_length;
u32 remaining_data_length = 0;
u32 data_length = 0;
struct smbdirect_send_batch _send_ctx;
u16 new_credits;
if (iter) {
header_length = sizeof(struct smbdirect_data_transfer);
if (WARN_ON_ONCE(remaining_data_length == 0 ||
iov_iter_count(iter) > remaining_data_length))
return -EINVAL;
} else {
/* If this is a packet without payload, don't send padding */
header_length = offsetof(struct smbdirect_data_transfer, padding);
if (WARN_ON_ONCE(remaining_data_length))
return -EINVAL;
}
if (!send_ctx) {
@@ -858,14 +862,6 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
new_credits = smbdirect_connection_grant_recv_credits(sc);
}
if (iter)
data_length = iov_iter_count(iter);
if (_remaining_data_length) {
*_remaining_data_length -= data_length;
remaining_data_length = *_remaining_data_length;
}
msg = smbdirect_connection_alloc_send_io(sc);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
@@ -894,14 +890,14 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
.local_dma_lkey = sc->ib.pd->local_dma_lkey,
.direction = DMA_TO_DEVICE,
};
size_t payload_len = umin(iov_iter_count(iter),
sp->max_send_size - sizeof(*packet));
ret = smbdirect_map_sges_from_iter(iter, data_length, &extract);
ret = smbdirect_map_sges_from_iter(iter, payload_len, &extract);
if (ret < 0)
goto err;
if (WARN_ON_ONCE(ret != data_length)) {
ret = -EIO;
goto err;
}
data_length = ret;
remaining_data_length -= data_length;
msg->num_sge = extract.num_sge;
}
@@ -970,13 +966,9 @@ static int smb_direct_writev(struct ksmbd_transport *t,
struct smb_direct_transport *st = SMBD_TRANS(t);
struct smbdirect_socket *sc = &st->socket;
struct smbdirect_socket_parameters *sp = &sc->parameters;
size_t remaining_data_length;
size_t iov_idx;
size_t iov_ofs;
size_t max_iov_size = sp->max_send_size -
sizeof(struct smbdirect_data_transfer);
int ret;
struct smbdirect_send_batch send_ctx;
struct iov_iter iter;
int error = 0;
if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
@@ -985,112 +977,31 @@ static int smb_direct_writev(struct ksmbd_transport *t,
//FIXME: skip RFC1002 header..
if (WARN_ON_ONCE(niovs <= 1 || iov[0].iov_len != 4))
return -EINVAL;
buflen -= 4;
iov_idx = 1;
iov_ofs = 0;
iov_iter_kvec(&iter, ITER_SOURCE, iov, niovs, buflen);
iov_iter_advance(&iter, 4);
remaining_data_length = buflen;
ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
/*
* The size must fit into the negotiated
* fragmented send size.
*/
if (iov_iter_count(&iter) > sp->max_fragmented_send_size)
return -EMSGSIZE;
ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%zu\n",
iov_iter_count(&iter));
smb_direct_send_ctx_init(&send_ctx, need_invalidate, remote_key);
while (remaining_data_length) {
struct kvec vecs[SMBDIRECT_SEND_IO_MAX_SGE - 1]; /* minus smbdirect hdr */
size_t possible_bytes = max_iov_size;
size_t possible_vecs;
size_t bytes = 0;
size_t nvecs = 0;
struct iov_iter iter;
/*
* For the last message remaining_data_length should be
* have been 0 already!
*/
if (WARN_ON_ONCE(iov_idx >= niovs)) {
error = -EINVAL;
goto done;
}
/*
* We have 2 factors which limit the arguments we pass
* to smb_direct_post_send_data():
*
* 1. The number of supported sges for the send,
* while one is reserved for the smbdirect header.
* And we currently need one SGE per page.
* 2. The number of negotiated payload bytes per send.
*/
possible_vecs = min_t(size_t, ARRAY_SIZE(vecs), niovs - iov_idx);
while (iov_idx < niovs && possible_vecs && possible_bytes) {
struct kvec *v = &vecs[nvecs];
int page_count;
v->iov_base = ((u8 *)iov[iov_idx].iov_base) + iov_ofs;
v->iov_len = min_t(size_t,
iov[iov_idx].iov_len - iov_ofs,
possible_bytes);
page_count = smbdirect_get_buf_page_count(v->iov_base, v->iov_len);
if (page_count > possible_vecs) {
/*
* If the number of pages in the buffer
* is to much (because we currently require
* one SGE per page), we need to limit the
* length.
*
* We know possible_vecs is at least 1,
* so we always keep the first page.
*
* We need to calculate the number extra
* pages (epages) we can also keep.
*
* We calculate the number of bytes in the
* first page (fplen), this should never be
* larger than v->iov_len because page_count is
* at least 2, but adding a limitation feels
* better.
*
* Then we calculate the number of bytes (elen)
* we can keep for the extra pages.
*/
size_t epages = possible_vecs - 1;
size_t fpofs = offset_in_page(v->iov_base);
size_t fplen = min_t(size_t, PAGE_SIZE - fpofs, v->iov_len);
size_t elen = min_t(size_t, v->iov_len - fplen, epages*PAGE_SIZE);
v->iov_len = fplen + elen;
page_count = smbdirect_get_buf_page_count(v->iov_base, v->iov_len);
if (WARN_ON_ONCE(page_count > possible_vecs)) {
/*
* Something went wrong in the above
* logic...
*/
error = -EINVAL;
goto done;
}
}
possible_vecs -= page_count;
nvecs += 1;
possible_bytes -= v->iov_len;
bytes += v->iov_len;
iov_ofs += v->iov_len;
if (iov_ofs >= iov[iov_idx].iov_len) {
iov_idx += 1;
iov_ofs = 0;
}
}
iov_iter_kvec(&iter, ITER_SOURCE, vecs, nvecs, bytes);
ret = smb_direct_post_send_data(sc, &send_ctx,
&iter, &remaining_data_length);
while (iov_iter_count(&iter)) {
ret = smb_direct_post_send_data(sc,
&send_ctx,
&iter,
iov_iter_count(&iter));
if (unlikely(ret)) {
error = ret;
goto done;
break;
}
}
done:
ret = smb_direct_flush_send_list(sc, &send_ctx, true);
if (unlikely(!ret && error))
ret = error;