diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h index 649ea9d2ae9b..0a8875e0709b 100644 --- a/net/mac80211/ieee80211_i.h +++ b/net/mac80211/ieee80211_i.h @@ -430,7 +430,7 @@ struct ieee80211_mgd_auth_data { u8 ap_addr[ETH_ALEN] __aligned(2); - u16 sae_trans, sae_status; + u16 trans, status; size_t data_len; u8 data[]; }; diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c index 977303fdfd9f..0c31a0602ea9 100644 --- a/net/mac80211/mlme.c +++ b/net/mac80211/mlme.c @@ -4911,6 +4911,7 @@ static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata, case WLAN_AUTH_FILS_SK: case WLAN_AUTH_FILS_SK_PFS: case WLAN_AUTH_FILS_PK: + case WLAN_AUTH_EPPKE: break; case WLAN_AUTH_SHARED_KEY: if (ifmgd->auth_data->expected_transaction != 4) { @@ -8277,6 +8278,12 @@ static int ieee80211_auth(struct ieee80211_sub_if_data *sdata) if (WARN_ON_ONCE(!auth_data)) return -EINVAL; + if (auth_data->algorithm == WLAN_AUTH_EPPKE && + ieee80211_vif_is_mld(&sdata->vif) && + !cfg80211_find_ext_elem(WLAN_EID_EXT_EHT_MULTI_LINK, + auth_data->data, auth_data->data_len)) + return -EINVAL; + auth_data->tries++; if (auth_data->tries > IEEE80211_AUTH_MAX_TRIES) { @@ -8305,9 +8312,12 @@ static int ieee80211_auth(struct ieee80211_sub_if_data *sdata) auth_data->expected_transaction = 2; if (auth_data->algorithm == WLAN_AUTH_SAE) { - trans = auth_data->sae_trans; - status = auth_data->sae_status; + trans = auth_data->trans; + status = auth_data->status; auth_data->expected_transaction = trans; + } else if (auth_data->algorithm == WLAN_AUTH_EPPKE) { + trans = auth_data->trans; + status = auth_data->status; } if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) @@ -9222,6 +9232,9 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, case NL80211_AUTHTYPE_FILS_PK: auth_alg = WLAN_AUTH_FILS_PK; break; + case NL80211_AUTHTYPE_EPPKE: + auth_alg = WLAN_AUTH_EPPKE; + break; default: return -EOPNOTSUPP; } @@ -9246,12 +9259,14 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, auth_data->link_id = req->link_id; if (req->auth_data_len >= 4) { - if (req->auth_type == NL80211_AUTHTYPE_SAE) { + if (req->auth_type == NL80211_AUTHTYPE_SAE || + req->auth_type == NL80211_AUTHTYPE_EPPKE) { __le16 *pos = (__le16 *) req->auth_data; - auth_data->sae_trans = le16_to_cpu(pos[0]); - auth_data->sae_status = le16_to_cpu(pos[1]); + auth_data->trans = le16_to_cpu(pos[0]); + auth_data->status = le16_to_cpu(pos[1]); } + memcpy(auth_data->data, req->auth_data + 4, req->auth_data_len - 4); auth_data->data_len += req->auth_data_len - 4; @@ -9302,7 +9317,11 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, * out SAE Confirm. */ if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE && - auth_data->peer_confirmed && auth_data->sae_trans == 2) + auth_data->peer_confirmed && auth_data->trans == 2) + ieee80211_mark_sta_auth(sdata); + + if (cont_auth && req->auth_type == NL80211_AUTHTYPE_EPPKE && + auth_data->trans == 3) ieee80211_mark_sta_auth(sdata); if (ifmgd->associated) {