Skip to main content

zeroclaw_gateway/
tls.rs

//! TLS and mutual TLS (mTLS) support for the gateway server.
//!
//! Builds a [`rustls::ServerConfig`] from the gateway TLS configuration,
//! optionally requesting or requiring client certificates verified against a
//! trusted CA, with optional certificate pinning (SHA-256 fingerprint matching).
6
7use anyhow::{Context, Result};
8use rustls::RootCertStore;
9use rustls::pki_types::{CertificateDer, PrivateKeyDer};
10use rustls::server::WebPkiClientVerifier;
11use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
12use sha2::{Digest, Sha256};
13use std::sync::Arc;
14use tokio_rustls::TlsAcceptor;
15use zeroclaw_config::schema::{GatewayClientAuthConfig, GatewayTlsConfig};
16
17/// Build a [`TlsAcceptor`] from the gateway TLS configuration.
18pub fn build_tls_acceptor(config: &GatewayTlsConfig) -> Result<TlsAcceptor> {
19    let server_config = build_server_config(config)?;
20    Ok(TlsAcceptor::from(Arc::new(server_config)))
21}
22
23/// Build a [`rustls::ServerConfig`] from the gateway TLS configuration.
24pub fn build_server_config(config: &GatewayTlsConfig) -> Result<rustls::ServerConfig> {
25    let certs = load_certs(&config.cert_path).with_context(|| {
26        format!(
27            "failed to load server certificate from {}",
28            config.cert_path
29        )
30    })?;
31    let key = load_private_key(&config.key_path)
32        .with_context(|| format!("failed to load private key from {}", config.key_path))?;
33
34    let client_auth_config = config.client_auth.as_ref().filter(|ca| ca.enabled);
35
36    let builder = rustls::ServerConfig::builder();
37
38    let server_config = if let Some(client_auth) = client_auth_config {
39        let verifier = build_client_verifier(client_auth)
40            .context("failed to build client certificate verifier")?;
41        builder
42            .with_client_cert_verifier(verifier)
43            .with_single_cert(certs, key)
44            .context("invalid server certificate or key")?
45    } else {
46        builder
47            .with_no_client_auth()
48            .with_single_cert(certs, key)
49            .context("invalid server certificate or key")?
50    };
51
52    Ok(server_config)
53}
54
55/// Build a client certificate verifier from the client auth configuration.
56fn build_client_verifier(config: &GatewayClientAuthConfig) -> Result<Arc<dyn ClientCertVerifier>> {
57    let ca_certs = load_certs(&config.ca_cert_path)
58        .with_context(|| format!("failed to load CA certificate from {}", config.ca_cert_path))?;
59
60    let mut root_store = RootCertStore::empty();
61    for cert in &ca_certs {
62        root_store
63            .add(cert.clone())
64            .context("failed to add CA certificate to root store")?;
65    }
66
67    let base_verifier = if config.require_client_cert {
68        WebPkiClientVerifier::builder(Arc::new(root_store))
69            .build()
70            .context("failed to build WebPKI client verifier")?
71    } else {
72        WebPkiClientVerifier::builder(Arc::new(root_store))
73            .allow_unauthenticated()
74            .build()
75            .context("failed to build WebPKI client verifier (optional auth)")?
76    };
77
78    if config.pinned_certs.is_empty() {
79        Ok(base_verifier)
80    } else {
81        let normalized: Vec<String> = config
82            .pinned_certs
83            .iter()
84            .map(|fp| fp.replace(':', "").to_lowercase())
85            .collect();
86        Ok(Arc::new(PinnedCertVerifier {
87            inner: base_verifier,
88            pinned_fingerprints: normalized,
89        }))
90    }
91}
92
93/// Compute the SHA-256 fingerprint of a DER-encoded certificate.
94pub fn cert_sha256_fingerprint(cert_der: &[u8]) -> String {
95    let mut hasher = Sha256::new();
96    hasher.update(cert_der);
97    let hash = hasher.finalize();
98    hex::encode(hash)
99}
100
101/// A client certificate verifier that delegates to a base verifier and then
102/// checks that the presented certificate matches one of the pinned SHA-256
103/// fingerprints.
104#[derive(Debug)]
105struct PinnedCertVerifier {
106    inner: Arc<dyn ClientCertVerifier>,
107    pinned_fingerprints: Vec<String>,
108}
109
110impl ClientCertVerifier for PinnedCertVerifier {
111    fn offer_client_auth(&self) -> bool {
112        self.inner.offer_client_auth()
113    }
114
115    fn client_auth_mandatory(&self) -> bool {
116        self.inner.client_auth_mandatory()
117    }
118
119    fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
120        self.inner.root_hint_subjects()
121    }
122
123    fn verify_client_cert(
124        &self,
125        end_entity: &CertificateDer<'_>,
126        intermediates: &[CertificateDer<'_>],
127        now: rustls::pki_types::UnixTime,
128    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
129        // First, run the standard WebPKI verification.
130        self.inner
131            .verify_client_cert(end_entity, intermediates, now)?;
132
133        // Then check the fingerprint against the pinned set.
134        let fingerprint = cert_sha256_fingerprint(end_entity.as_ref());
135        if self.pinned_fingerprints.contains(&fingerprint) {
136            Ok(ClientCertVerified::assertion())
137        } else {
138            Err(rustls::Error::General(format!(
139                "client certificate fingerprint {fingerprint} is not in the pinned set"
140            )))
141        }
142    }
143
144    fn verify_tls12_signature(
145        &self,
146        message: &[u8],
147        cert: &CertificateDer<'_>,
148        dss: &rustls::DigitallySignedStruct,
149    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
150        self.inner.verify_tls12_signature(message, cert, dss)
151    }
152
153    fn verify_tls13_signature(
154        &self,
155        message: &[u8],
156        cert: &CertificateDer<'_>,
157        dss: &rustls::DigitallySignedStruct,
158    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
159        self.inner.verify_tls13_signature(message, cert, dss)
160    }
161
162    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
163        self.inner.supported_verify_schemes()
164    }
165}
166
167/// Load PEM-encoded certificates from a file.
168fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
169    let file = std::fs::File::open(path)
170        .with_context(|| format!("cannot open certificate file: {path}"))?;
171    let mut reader = std::io::BufReader::new(file);
172    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
173        .collect::<std::result::Result<Vec<_>, _>>()
174        .with_context(|| format!("failed to parse PEM certificates from {path}"))?;
175    if certs.is_empty() {
176        anyhow::bail!("no certificates found in {path}");
177    }
178    Ok(certs)
179}
180
181/// Load a PEM-encoded private key from a file.
182fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
183    let file = std::fs::File::open(path)
184        .with_context(|| format!("cannot open private key file: {path}"))?;
185    let mut reader = std::io::BufReader::new(file);
186    let key = rustls_pemfile::private_key(&mut reader)
187        .with_context(|| format!("failed to parse private key from {path}"))?
188        .ok_or_else(|| {
189            ::zeroclaw_log::record!(
190                ERROR,
191                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
192                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
193                    .with_attrs(::serde_json::json!({"path": path})),
194                "TLS private key file contains no key"
195            );
196            anyhow::Error::msg(format!("no private key found in {path}"))
197        })?;
198    Ok(key)
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Ensure the rustls `CryptoProvider` is installed (idempotent).
    fn ensure_crypto_provider() {
        // `install_default` returns `Err` once a provider is already installed,
        // which is expected when several tests share one process.
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// A self-signed test CA.
    ///
    /// The certificate object is kept so that leaf certificates are signed by
    /// exactly the CA certificate written to disk, not a regenerated look-alike.
    struct TestCa {
        cert: rcgen::Certificate,
        key: rcgen::KeyPair,
    }

    impl TestCa {
        /// PEM encoding of the CA certificate.
        fn cert_pem(&self) -> String {
            self.cert.pem()
        }
    }

    /// Generate a self-signed CA certificate and key pair.
    fn test_ca() -> TestCa {
        let key = rcgen::KeyPair::generate().unwrap();
        let mut params = rcgen::CertificateParams::new(vec!["Test CA".to_string()]).unwrap();
        params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let cert = params.self_signed(&key).unwrap();
        TestCa { cert, key }
    }

    /// Generate a leaf certificate for `name`, signed by `ca`, with the given
    /// extended key usages. Returns `(cert_pem, key_pem)`.
    fn test_leaf_cert(
        ca: &TestCa,
        name: &str,
        usages: Vec<rcgen::ExtendedKeyUsagePurpose>,
    ) -> (String, String) {
        let mut params = rcgen::CertificateParams::new(vec![name.to_string()]).unwrap();
        params.is_ca = rcgen::IsCa::NoCa;
        params.extended_key_usages = usages;
        let key = rcgen::KeyPair::generate().unwrap();
        let cert = params.signed_by(&key, &ca.cert, &ca.key).unwrap();
        (cert.pem(), key.serialize_pem())
    }

    /// Generate a `localhost` server certificate signed by `ca`.
    /// Returns `(cert_pem, key_pem)`.
    fn test_server_cert(ca: &TestCa) -> (String, String) {
        test_leaf_cert(ca, "localhost", Vec::new())
    }

    /// Generate a client-authentication certificate signed by `ca`, as DER.
    fn test_client_cert_der(ca: &TestCa) -> CertificateDer<'static> {
        let (cert_pem, _key_pem) = test_leaf_cert(
            ca,
            "client.test",
            vec![rcgen::ExtendedKeyUsagePurpose::ClientAuth],
        );
        let cert_file = write_temp_file(&cert_pem);
        load_certs(&path_of(&cert_file)).unwrap().remove(0)
    }

    fn write_temp_file(content: &str) -> tempfile::NamedTempFile {
        use std::io::Write;
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(content.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    /// The temp file's path as the `String` the config structs store.
    fn path_of(file: &tempfile::NamedTempFile) -> String {
        file.path().to_str().unwrap().to_string()
    }

    /// Write a CA-signed server certificate and its key to temp files.
    /// Returns `(cert_file, key_file)`. Keep both alive while their paths are used.
    fn server_identity_files(ca: &TestCa) -> (tempfile::NamedTempFile, tempfile::NamedTempFile) {
        let (cert_pem, key_pem) = test_server_cert(ca);
        (write_temp_file(&cert_pem), write_temp_file(&key_pem))
    }

    /// An enabled TLS config pointing at the given server identity files.
    fn tls_config(
        cert_file: &tempfile::NamedTempFile,
        key_file: &tempfile::NamedTempFile,
        client_auth: Option<GatewayClientAuthConfig>,
    ) -> GatewayTlsConfig {
        GatewayTlsConfig {
            enabled: true,
            cert_path: path_of(cert_file),
            key_path: path_of(key_file),
            client_auth,
        }
    }

    /// Enabled client-auth settings trusting the CA stored at `ca_cert_path`.
    fn enabled_client_auth(
        ca_cert_path: String,
        require_client_cert: bool,
        pinned_certs: Vec<String>,
    ) -> GatewayClientAuthConfig {
        GatewayClientAuthConfig {
            enabled: true,
            ca_cert_path,
            require_client_cert,
            pinned_certs,
        }
    }

    /// A [`PinnedCertVerifier`] over a mandatory WebPKI verifier trusting `ca`.
    fn pinned_verifier(ca: &TestCa, pins: Vec<String>) -> PinnedCertVerifier {
        let ca_file = write_temp_file(&ca.cert_pem());
        let mut roots = RootCertStore::empty();
        for cert in load_certs(&path_of(&ca_file)).unwrap() {
            roots.add(cert).unwrap();
        }
        let inner = WebPkiClientVerifier::builder(Arc::new(roots))
            .build()
            .unwrap();
        PinnedCertVerifier {
            inner,
            pinned_fingerprints: pins,
        }
    }

    #[test]
    fn test_load_valid_cert_and_key() {
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);

        let certs = load_certs(&path_of(&cert_file)).unwrap();
        assert!(!certs.is_empty());

        let _key = load_private_key(&path_of(&key_file)).unwrap();
    }

    #[test]
    fn test_invalid_cert_path_produces_clear_error() {
        let err = load_certs("/nonexistent/path/cert.pem").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("cannot open certificate file"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_invalid_key_path_produces_clear_error() {
        let err = load_private_key("/nonexistent/path/key.pem").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("cannot open private key file"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_build_server_config_no_client_auth() {
        ensure_crypto_provider();
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);
        let config = tls_config(&cert_file, &key_file, None);

        // Should build successfully without client auth.
        let _server_config = build_server_config(&config).unwrap();
    }

    #[test]
    fn test_build_server_config_with_client_auth() {
        ensure_crypto_provider();
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);
        let ca_file = write_temp_file(&ca.cert_pem());
        let config = tls_config(
            &cert_file,
            &key_file,
            Some(enabled_client_auth(path_of(&ca_file), true, vec![])),
        );

        // Should build successfully with mandatory client auth.
        let _server_config = build_server_config(&config).unwrap();
    }

    #[test]
    fn test_build_server_config_client_auth_optional() {
        ensure_crypto_provider();
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);
        let ca_file = write_temp_file(&ca.cert_pem());
        let config = tls_config(
            &cert_file,
            &key_file,
            Some(enabled_client_auth(path_of(&ca_file), false, vec![])),
        );

        // Should build successfully with optional client auth.
        let _server_config = build_server_config(&config).unwrap();
    }

    #[test]
    fn test_cert_fingerprint_matching() {
        let ca = test_ca();
        let ca_file = write_temp_file(&ca.cert_pem());
        let certs = load_certs(&path_of(&ca_file)).unwrap();
        let fingerprint = cert_sha256_fingerprint(certs[0].as_ref());

        // Fingerprint should be a 64-char hex string (SHA-256).
        assert_eq!(fingerprint.len(), 64);
        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));

        // Same cert should produce the same fingerprint.
        let fingerprint2 = cert_sha256_fingerprint(certs[0].as_ref());
        assert_eq!(fingerprint, fingerprint2);
    }

    #[test]
    fn test_fingerprint_differs_for_different_certs() {
        let f1 = write_temp_file(&test_ca().cert_pem());
        let f2 = write_temp_file(&test_ca().cert_pem());
        let certs1 = load_certs(&path_of(&f1)).unwrap();
        let certs2 = load_certs(&path_of(&f2)).unwrap();
        let fp1 = cert_sha256_fingerprint(certs1[0].as_ref());
        let fp2 = cert_sha256_fingerprint(certs2[0].as_ref());
        assert_ne!(fp1, fp2);
    }

    #[test]
    fn test_fingerprint_known_vector() {
        // SHA-256 of the empty input (FIPS 180-4 test vector).
        assert_eq!(
            cert_sha256_fingerprint(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_config_defaults_deserialization() {
        let toml_str = r#"
            cert_path = "/tmp/cert.pem"
            key_path = "/tmp/key.pem"
        "#;
        let config: GatewayTlsConfig = toml::from_str(toml_str).unwrap();
        assert!(!config.enabled);
        assert!(config.client_auth.is_none());
    }

    #[test]
    fn test_client_auth_config_defaults() {
        let toml_str = r#"
            ca_cert_path = "/tmp/ca.pem"
        "#;
        let config: GatewayClientAuthConfig = toml::from_str(toml_str).unwrap();
        assert!(!config.enabled);
        assert!(config.require_client_cert);
        assert!(config.pinned_certs.is_empty());
    }

    #[test]
    fn test_build_server_config_with_pinning() {
        ensure_crypto_provider();
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);
        let ca_file = write_temp_file(&ca.cert_pem());
        let config = tls_config(
            &cert_file,
            &key_file,
            Some(enabled_client_auth(
                path_of(&ca_file),
                true,
                vec!["aabbccdd".to_string()],
            )),
        );

        // Should build successfully - pinning is checked at connection time, not config time.
        let _server_config = build_server_config(&config).unwrap();
    }

    #[test]
    fn test_pinned_verifier_accepts_pinned_client_cert() {
        ensure_crypto_provider();
        let ca = test_ca();
        let client_cert = test_client_cert_der(&ca);
        let pin = cert_sha256_fingerprint(client_cert.as_ref());
        let verifier = pinned_verifier(&ca, vec![pin]);

        // Offer/mandatory flags are delegated to the inner (mandatory) verifier.
        assert!(verifier.offer_client_auth());
        assert!(verifier.client_auth_mandatory());
        assert!(verifier
            .verify_client_cert(&client_cert, &[], rustls::pki_types::UnixTime::now())
            .is_ok());
    }

    #[test]
    fn test_pinned_verifier_rejects_unpinned_client_cert() {
        ensure_crypto_provider();
        let ca = test_ca();
        let client_cert = test_client_cert_der(&ca);
        // A well-formed pin that belongs to no certificate used here.
        let verifier = pinned_verifier(&ca, vec!["00".repeat(32)]);

        assert!(verifier
            .verify_client_cert(&client_cert, &[], rustls::pki_types::UnixTime::now())
            .is_err());
    }

    #[test]
    fn test_empty_cert_file_produces_error() {
        let empty_file = write_temp_file("");
        let err = load_certs(&path_of(&empty_file)).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("no certificates found"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_disabled_client_auth_skipped() {
        ensure_crypto_provider();
        let ca = test_ca();
        let (cert_file, key_file) = server_identity_files(&ca);

        // client_auth present but enabled=false should be treated as no client auth.
        let config = tls_config(
            &cert_file,
            &key_file,
            Some(GatewayClientAuthConfig {
                enabled: false,
                ca_cert_path: "/nonexistent".to_string(),
                require_client_cert: true,
                pinned_certs: vec![],
            }),
        );

        // Should succeed because client_auth.enabled=false skips the CA loading.
        let _server_config = build_server_config(&config).unwrap();
    }
}