diff --git a/fs/smb/server/server.c b/fs/smb/server/server.c index 3cea16050e4f..bedc8390b6db 100644 --- a/fs/smb/server/server.c +++ b/fs/smb/server/server.c @@ -95,7 +95,7 @@ static inline int check_conn_state(struct ksmbd_work *work) if (ksmbd_conn_exiting(work->conn) || ksmbd_conn_need_reconnect(work->conn)) { - rsp_hdr = work->response_buf; + rsp_hdr = smb2_get_msg(work->response_buf); rsp_hdr->Status.CifsError = STATUS_CONNECTION_DISCONNECTED; return 1; } diff --git a/fs/smb/server/smb_common.c b/fs/smb/server/smb_common.c index b23203a1c286..6d7b4449276b 100644 --- a/fs/smb/server/smb_common.c +++ b/fs/smb/server/smb_common.c @@ -140,7 +140,7 @@ int ksmbd_verify_smb_message(struct ksmbd_work *work) if (smb2_hdr->ProtocolId == SMB2_PROTO_NUMBER) return ksmbd_smb2_check_message(work); - hdr = work->request_buf; + hdr = smb2_get_msg(work->request_buf); if (*(__le32 *)hdr->Protocol == SMB1_PROTO_NUMBER && hdr->Command == SMB_COM_NEGOTIATE) { work->conn->outstanding_credits++; @@ -278,15 +278,14 @@ static int ksmbd_negotiate_smb_dialect(void *buf) req->DialectCount); } - proto = *(__le32 *)((struct smb_hdr *)buf)->Protocol; if (proto == SMB1_PROTO_NUMBER) { struct smb_negotiate_req *req; - req = (struct smb_negotiate_req *)buf; + req = (struct smb_negotiate_req *)smb2_get_msg(buf); if (le16_to_cpu(req->ByteCount) < 2) goto err_out; - if (offsetof(struct smb_negotiate_req, DialectsArray) - 4 + + if (offsetof(struct smb_negotiate_req, DialectsArray) + le16_to_cpu(req->ByteCount) > smb_buf_length) { goto err_out; } @@ -320,8 +319,8 @@ static u16 get_smb1_cmd_val(struct ksmbd_work *work) */ static int init_smb1_rsp_hdr(struct ksmbd_work *work) { - struct smb_hdr *rsp_hdr = (struct smb_hdr *)work->response_buf; - struct smb_hdr *rcv_hdr = (struct smb_hdr *)work->request_buf; + struct smb_hdr *rsp_hdr = (struct smb_hdr *)smb2_get_msg(work->response_buf); + struct smb_hdr *rcv_hdr = (struct smb_hdr *)smb2_get_msg(work->request_buf); rsp_hdr->Command = SMB_COM_NEGOTIATE; *(__le32 *)rsp_hdr->Protocol = SMB1_PROTO_NUMBER; @@ -412,9 +411,10 @@ static int init_smb1_server(struct ksmbd_conn *conn) int ksmbd_init_smb_server(struct ksmbd_conn *conn) { + struct smb_hdr *rcv_hdr = (struct smb_hdr *)smb2_get_msg(conn->request_buf); __le32 proto; - proto = *(__le32 *)((struct smb_hdr *)conn->request_buf)->Protocol; + proto = *(__le32 *)rcv_hdr->Protocol; if (conn->need_neg == false) { if (proto == SMB1_PROTO_NUMBER) return -EINVAL; @@ -572,12 +572,12 @@ static int __smb2_negotiate(struct ksmbd_conn *conn) static int smb_handle_negotiate(struct ksmbd_work *work) { - struct smb_negotiate_rsp *neg_rsp = work->response_buf; + struct smb_negotiate_rsp *neg_rsp = smb2_get_msg(work->response_buf); ksmbd_debug(SMB, "Unsupported SMB1 protocol\n"); - if (ksmbd_iov_pin_rsp(work, (void *)neg_rsp + 4, - sizeof(struct smb_negotiate_rsp) - 4)) + if (ksmbd_iov_pin_rsp(work, (void *)neg_rsp, + sizeof(struct smb_negotiate_rsp))) return -ENOMEM; neg_rsp->hdr.Status.CifsError = STATUS_SUCCESS;