1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
67use std::{num::NonZeroU32, time::Duration};
89use camino::Utf8PathBuf;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
1314use super::ConfigurationSection;
15use crate::schema;
/// Default connection URI: the bare `postgresql://` scheme, with no host,
/// credentials or database name filled in.
// Returns an `Option` because it feeds the `Option<String>` `uri` field, which
// is why `unnecessary_wraps` is silenced.
#[allow(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    let scheme_only = String::from("postgresql://");
    Some(scheme_only)
}
/// Default upper bound on the size of the connection pool: 10 connections.
fn default_max_connections() -> NonZeroU32 {
    const POOL_SIZE: u32 = 10;
    NonZeroU32::new(POOL_SIZE).expect("the default pool size is a non-zero constant")
}
/// Default time allowed for establishing a database connection: 30 seconds.
fn default_connect_timeout() -> Duration {
    let seconds = 30;
    Duration::new(seconds, 0)
}
/// Default maximum idle duration for a pooled connection: ten minutes.
// Wrapped in `Option` to match the `Option<Duration>` field it initialises,
// hence the lint allowance.
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    const TEN_MINUTES_IN_SECS: u64 = 10 * 60;
    Some(Duration::from_secs(TEN_MINUTES_IN_SECS))
}
/// Default maximum lifetime of a pooled connection: thirty minutes.
// Wrapped in `Option` to match the `Option<Duration>` field it initialises,
// hence the lint allowance.
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    const THIRTY_MINUTES_IN_SECS: u64 = 30 * 60;
    Some(Duration::from_secs(THIRTY_MINUTES_IN_SECS))
}
3940impl Default for DatabaseConfig {
41fn default() -> Self {
42Self {
43 uri: default_connection_string(),
44 host: None,
45 port: None,
46 socket: None,
47 username: None,
48 password: None,
49 database: None,
50 ssl_mode: None,
51 ssl_ca: None,
52 ssl_ca_file: None,
53 ssl_certificate: None,
54 ssl_certificate_file: None,
55 ssl_key: None,
56 ssl_key_file: None,
57 max_connections: default_max_connections(),
58 min_connections: Default::default(),
59 connect_timeout: default_connect_timeout(),
60 idle_timeout: default_idle_timeout(),
61 max_lifetime: default_max_lifetime(),
62 }
63 }
64}
6566/// Options for controlling the level of protection provided for PostgreSQL SSL
67/// connections.
68#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
69#[serde(rename_all = "kebab-case")]
70pub enum PgSslMode {
71/// Only try a non-SSL connection.
72Disable,
7374/// First try a non-SSL connection; if that fails, try an SSL connection.
75Allow,
7677/// First try an SSL connection; if that fails, try a non-SSL connection.
78Prefer,
7980/// Only try an SSL connection. If a root CA file is present, verify the
81 /// connection in the same way as if `VerifyCa` was specified.
82Require,
8384/// Only try an SSL connection, and verify that the server certificate is
85 /// issued by a trusted certificate authority (CA).
86VerifyCa,
8788/// Only try an SSL connection; verify that the server certificate is issued
89 /// by a trusted CA and that the requested server host name matches that
90 /// in the certificate.
91VerifyFull,
92}
9394/// Database connection configuration
95#[serde_as]
96#[derive(Debug, Serialize, Deserialize, JsonSchema)]
97pub struct DatabaseConfig {
98/// Connection URI
99 ///
100 /// This must not be specified if `host`, `port`, `socket`, `username`,
101 /// `password`, or `database` are specified.
102#[serde(skip_serializing_if = "Option::is_none")]
103 #[schemars(url, default = "default_connection_string")]
104pub uri: Option<String>,
105106/// Name of host to connect to
107 ///
108 /// This must not be specified if `uri` is specified.
109#[serde(skip_serializing_if = "Option::is_none")]
110 #[schemars(with = "Option::<schema::Hostname>")]
111pub host: Option<String>,
112113/// Port number to connect at the server host
114 ///
115 /// This must not be specified if `uri` is specified.
116#[serde(skip_serializing_if = "Option::is_none")]
117 #[schemars(range(min = 1, max = 65535))]
118pub port: Option<u16>,
119120/// Directory containing the UNIX socket to connect to
121 ///
122 /// This must not be specified if `uri` is specified.
123#[serde(skip_serializing_if = "Option::is_none")]
124 #[schemars(with = "Option<String>")]
125pub socket: Option<Utf8PathBuf>,
126127/// PostgreSQL user name to connect as
128 ///
129 /// This must not be specified if `uri` is specified.
130#[serde(skip_serializing_if = "Option::is_none")]
131pub username: Option<String>,
132133/// Password to be used if the server demands password authentication
134 ///
135 /// This must not be specified if `uri` is specified.
136#[serde(skip_serializing_if = "Option::is_none")]
137pub password: Option<String>,
138139/// The database name
140 ///
141 /// This must not be specified if `uri` is specified.
142#[serde(skip_serializing_if = "Option::is_none")]
143pub database: Option<String>,
144145/// How to handle SSL connections
146#[serde(skip_serializing_if = "Option::is_none")]
147pub ssl_mode: Option<PgSslMode>,
148149/// The PEM-encoded root certificate for SSL connections
150 ///
151 /// This must not be specified if the `ssl_ca_file` option is specified.
152#[serde(skip_serializing_if = "Option::is_none")]
153pub ssl_ca: Option<String>,
154155/// Path to the root certificate for SSL connections
156 ///
157 /// This must not be specified if the `ssl_ca` option is specified.
158#[serde(skip_serializing_if = "Option::is_none")]
159 #[schemars(with = "Option<String>")]
160pub ssl_ca_file: Option<Utf8PathBuf>,
161162/// The PEM-encoded client certificate for SSL connections
163 ///
164 /// This must not be specified if the `ssl_certificate_file` option is
165 /// specified.
166#[serde(skip_serializing_if = "Option::is_none")]
167pub ssl_certificate: Option<String>,
168169/// Path to the client certificate for SSL connections
170 ///
171 /// This must not be specified if the `ssl_certificate` option is specified.
172#[serde(skip_serializing_if = "Option::is_none")]
173 #[schemars(with = "Option<String>")]
174pub ssl_certificate_file: Option<Utf8PathBuf>,
175176/// The PEM-encoded client key for SSL connections
177 ///
178 /// This must not be specified if the `ssl_key_file` option is specified.
179#[serde(skip_serializing_if = "Option::is_none")]
180pub ssl_key: Option<String>,
181182/// Path to the client key for SSL connections
183 ///
184 /// This must not be specified if the `ssl_key` option is specified.
185#[serde(skip_serializing_if = "Option::is_none")]
186 #[schemars(with = "Option<String>")]
187pub ssl_key_file: Option<Utf8PathBuf>,
188189/// Set the maximum number of connections the pool should maintain
190#[serde(default = "default_max_connections")]
191pub max_connections: NonZeroU32,
192193/// Set the minimum number of connections the pool should maintain
194#[serde(default)]
195pub min_connections: u32,
196197/// Set the amount of time to attempt connecting to the database
198#[schemars(with = "u64")]
199 #[serde(default = "default_connect_timeout")]
200 #[serde_as(as = "serde_with::DurationSeconds<u64>")]
201pub connect_timeout: Duration,
202203/// Set a maximum idle duration for individual connections
204#[schemars(with = "Option<u64>")]
205 #[serde(
206 default = "default_idle_timeout",
207 skip_serializing_if = "Option::is_none"
208)]
209 #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
210pub idle_timeout: Option<Duration>,
211212/// Set the maximum lifetime of individual connections
213#[schemars(with = "u64")]
214 #[serde(
215 default = "default_max_lifetime",
216 skip_serializing_if = "Option::is_none"
217)]
218 #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
219pub max_lifetime: Option<Duration>,
220}
221222impl ConfigurationSection for DatabaseConfig {
223const PATH: Option<&'static str> = Some("database");
224225fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
226let metadata = figment.find_metadata(Self::PATH.unwrap());
227let annotate = |mut error: figment::Error| {
228 error.metadata = metadata.cloned();
229 error.profile = Some(figment::Profile::Default);
230 error.path = vec![Self::PATH.unwrap().to_owned()];
231Err(error)
232 };
233234// Check that the user did not specify both `uri` and the split options at the
235 // same time
236let has_split_options = self.host.is_some()
237 || self.port.is_some()
238 || self.socket.is_some()
239 || self.username.is_some()
240 || self.password.is_some()
241 || self.database.is_some();
242243if self.uri.is_some() && has_split_options {
244return annotate(figment::error::Error::from(
245"uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
246 ));
247 }
248249if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
250return annotate(figment::error::Error::from(
251"ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
252 ));
253 }
254255if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
256return annotate(figment::error::Error::from(
257"ssl_certificate must not be specified if ssl_certificate_file is specified"
258.to_owned(),
259 ));
260 }
261262if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
263return annotate(figment::error::Error::from(
264"ssl_key must not be specified if ssl_key_file is specified".to_owned(),
265 ));
266 }
267268if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
269 ^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
270 {
271return annotate(figment::error::Error::from(
272"both a ssl_certificate and a ssl_key must be set at the same time or none of them"
273.to_owned(),
274 ));
275 }
276277Ok(())
278 }
279}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    /// A plain connection URI is loaded as-is and passes validation.
    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(
                config.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    /// The split connection options are loaded and accepted without a `uri`.
    #[test]
    fn load_split_options() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      host: localhost
                      port: 5432
                      username: mas
                      database: mas
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(config.uri, None);
            assert_eq!(config.host.as_deref(), Some("localhost"));
            assert_eq!(config.port, Some(5432));
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    /// Mixing `uri` with a split option is rejected by validation.
    #[test]
    fn reject_uri_with_split_options() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      host: localhost
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }

    /// A client certificate without a matching key is rejected by validation.
    #[test]
    fn reject_certificate_without_key() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      ssl_certificate_file: /path/to/cert.pem
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }
}