1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
//! Types and helpers for the [Proof Key for Code Exchange] (PKCE) extension:
//! computing code challenges and verifying them against code verifiers.
//!
//! [Proof Key for Code Exchange]: https://www.rfc-editor.org/rfc/rfc7636
1011use std::borrow::Cow;
1213use base64ct::{Base64UrlUnpadded, Encoding};
14use mas_iana::oauth::PkceCodeChallengeMethod;
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17use thiserror::Error;
1819/// Errors that can occur when verifying a code challenge.
20#[derive(Debug, Error, PartialEq, Eq)]
21pub enum CodeChallengeError {
22/// The code verifier should be at least 43 characters long.
23#[error("code_verifier should be at least 43 characters long")]
24TooShort,
2526/// The code verifier should be at most 128 characters long.
27#[error("code_verifier should be at most 128 characters long")]
28TooLong,
2930/// The code verifier contains invalid characters.
31#[error("code_verifier contains invalid characters")]
32InvalidCharacters,
3334/// The challenge verification failed.
35#[error("challenge verification failed")]
36VerificationFailed,
3738/// The challenge method is unsupported.
39#[error("unknown challenge method")]
40UnknownChallengeMethod,
41}
4243fn validate_verifier(verifier: &str) -> Result<(), CodeChallengeError> {
44if verifier.len() < 43 {
45return Err(CodeChallengeError::TooShort);
46 }
4748if verifier.len() > 128 {
49return Err(CodeChallengeError::TooLong);
50 }
5152if !verifier
53 .chars()
54 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
55 {
56return Err(CodeChallengeError::InvalidCharacters);
57 }
5859Ok(())
60}
6162/// Helper trait to compute and verify code challenges.
63pub trait CodeChallengeMethodExt {
64/// Compute the challenge for a given verifier
65 ///
66 /// # Errors
67 ///
68 /// Returns an error if the verifier did not adhere to the rules defined by
69 /// the RFC in terms of length and allowed characters
70fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError>;
7172/// Verify that a given verifier is valid for the given challenge
73 ///
74 /// # Errors
75 ///
76 /// Returns an error if the verifier did not match the challenge, or if the
77 /// verifier did not adhere to the rules defined by the RFC in terms of
78 /// length and allowed characters
79fn verify(&self, challenge: &str, verifier: &str) -> Result<(), CodeChallengeError>
80where
81Self: Sized,
82 {
83if self.compute_challenge(verifier)? == challenge {
84Ok(())
85 } else {
86Err(CodeChallengeError::VerificationFailed)
87 }
88 }
89}
9091impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
92fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError> {
93 validate_verifier(verifier)?;
9495let challenge = match self {
96Self::Plain => verifier.into(),
97Self::S256 => {
98let mut hasher = Sha256::new();
99 hasher.update(verifier.as_bytes());
100let hash = hasher.finalize();
101let verifier = Base64UrlUnpadded::encode_string(&hash);
102 verifier.into()
103 }
104_ => return Err(CodeChallengeError::UnknownChallengeMethod),
105 };
106107Ok(challenge)
108 }
109}
110111/// The code challenge data added to an authorization request.
112#[derive(Clone, Serialize, Deserialize)]
113pub struct AuthorizationRequest {
114/// The code challenge method.
115pub code_challenge_method: PkceCodeChallengeMethod,
116117/// The code challenge computed from the verifier and the method.
118pub code_challenge: String,
119}
120121/// The code challenge data added to a token request.
122#[derive(Clone, Serialize, Deserialize)]
123pub struct TokenRequest {
124/// The code challenge verifier.
125pub code_challenge_verifier: String,
126}
#[cfg(test)]
mod tests {
    use super::*;

    /// Code verifier taken from the RFC7636 appendices.
    const RFC_VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";

    /// S256 code challenge matching [`RFC_VERIFIER`], from the same appendices.
    const RFC_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

    #[test]
    fn test_pkce_verification() {
        use PkceCodeChallengeMethod::{Plain, S256};

        assert_eq!(S256.verify(RFC_CHALLENGE, RFC_VERIFIER), Ok(()));

        // With the plain method, the challenge is the verifier itself
        assert_eq!(Plain.verify(RFC_CHALLENGE, RFC_CHALLENGE), Ok(()));

        // Well-formed verifiers which do not match the challenge
        assert_eq!(
            S256.verify(RFC_CHALLENGE, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
            Err(CodeChallengeError::VerificationFailed),
        );
        assert_eq!(
            Plain.verify(RFC_CHALLENGE, RFC_VERIFIER),
            Err(CodeChallengeError::VerificationFailed),
        );

        assert_eq!(
            S256.verify(RFC_CHALLENGE, "tooshort"),
            Err(CodeChallengeError::TooShort),
        );

        assert_eq!(
            S256.verify(RFC_CHALLENGE, "toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong"),
            Err(CodeChallengeError::TooLong),
        );

        assert_eq!(
            S256.verify(
                RFC_CHALLENGE,
                "this is long enough but has invalid characters in it"
            ),
            Err(CodeChallengeError::InvalidCharacters),
        );
    }

    #[test]
    fn test_verifier_length_boundaries() {
        use PkceCodeChallengeMethod::Plain;

        // 43 and 128 characters are the inclusive bounds
        let at_min = "a".repeat(43);
        let at_max = "a".repeat(128);
        assert_eq!(Plain.verify(&at_min, &at_min), Ok(()));
        assert_eq!(Plain.verify(&at_max, &at_max), Ok(()));

        // One character beyond either bound is rejected
        let below_min = "a".repeat(42);
        let above_max = "a".repeat(129);
        assert_eq!(
            Plain.verify(&below_min, &below_min),
            Err(CodeChallengeError::TooShort),
        );
        assert_eq!(
            Plain.verify(&above_max, &above_max),
            Err(CodeChallengeError::TooLong),
        );
    }

    #[test]
    fn test_verifier_allowed_punctuation() {
        use PkceCodeChallengeMethod::Plain;

        // The four punctuation characters accepted besides ASCII alphanumerics
        let verifier = "-._~".repeat(11);
        assert_eq!(Plain.verify(&verifier, &verifier), Ok(()));
    }
}