1use anyhow::{Context, Result};
8use rustls::RootCertStore;
9use rustls::pki_types::{CertificateDer, PrivateKeyDer};
10use rustls::server::WebPkiClientVerifier;
11use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
12use sha2::{Digest, Sha256};
13use std::sync::Arc;
14use tokio_rustls::TlsAcceptor;
15use zeroclaw_config::schema::{GatewayClientAuthConfig, GatewayTlsConfig};
16
17pub fn build_tls_acceptor(config: &GatewayTlsConfig) -> Result<TlsAcceptor> {
19 let server_config = build_server_config(config)?;
20 Ok(TlsAcceptor::from(Arc::new(server_config)))
21}
22
23pub fn build_server_config(config: &GatewayTlsConfig) -> Result<rustls::ServerConfig> {
25 let certs = load_certs(&config.cert_path).with_context(|| {
26 format!(
27 "failed to load server certificate from {}",
28 config.cert_path
29 )
30 })?;
31 let key = load_private_key(&config.key_path)
32 .with_context(|| format!("failed to load private key from {}", config.key_path))?;
33
34 let client_auth_config = config.client_auth.as_ref().filter(|ca| ca.enabled);
35
36 let builder = rustls::ServerConfig::builder();
37
38 let server_config = if let Some(client_auth) = client_auth_config {
39 let verifier = build_client_verifier(client_auth)
40 .context("failed to build client certificate verifier")?;
41 builder
42 .with_client_cert_verifier(verifier)
43 .with_single_cert(certs, key)
44 .context("invalid server certificate or key")?
45 } else {
46 builder
47 .with_no_client_auth()
48 .with_single_cert(certs, key)
49 .context("invalid server certificate or key")?
50 };
51
52 Ok(server_config)
53}
54
55fn build_client_verifier(config: &GatewayClientAuthConfig) -> Result<Arc<dyn ClientCertVerifier>> {
57 let ca_certs = load_certs(&config.ca_cert_path)
58 .with_context(|| format!("failed to load CA certificate from {}", config.ca_cert_path))?;
59
60 let mut root_store = RootCertStore::empty();
61 for cert in &ca_certs {
62 root_store
63 .add(cert.clone())
64 .context("failed to add CA certificate to root store")?;
65 }
66
67 let base_verifier = if config.require_client_cert {
68 WebPkiClientVerifier::builder(Arc::new(root_store))
69 .build()
70 .context("failed to build WebPKI client verifier")?
71 } else {
72 WebPkiClientVerifier::builder(Arc::new(root_store))
73 .allow_unauthenticated()
74 .build()
75 .context("failed to build WebPKI client verifier (optional auth)")?
76 };
77
78 if config.pinned_certs.is_empty() {
79 Ok(base_verifier)
80 } else {
81 let normalized: Vec<String> = config
82 .pinned_certs
83 .iter()
84 .map(|fp| fp.replace(':', "").to_lowercase())
85 .collect();
86 Ok(Arc::new(PinnedCertVerifier {
87 inner: base_verifier,
88 pinned_fingerprints: normalized,
89 }))
90 }
91}
92
93pub fn cert_sha256_fingerprint(cert_der: &[u8]) -> String {
95 let mut hasher = Sha256::new();
96 hasher.update(cert_der);
97 let hash = hasher.finalize();
98 hex::encode(hash)
99}
100
101#[derive(Debug)]
105struct PinnedCertVerifier {
106 inner: Arc<dyn ClientCertVerifier>,
107 pinned_fingerprints: Vec<String>,
108}
109
110impl ClientCertVerifier for PinnedCertVerifier {
111 fn offer_client_auth(&self) -> bool {
112 self.inner.offer_client_auth()
113 }
114
115 fn client_auth_mandatory(&self) -> bool {
116 self.inner.client_auth_mandatory()
117 }
118
119 fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
120 self.inner.root_hint_subjects()
121 }
122
123 fn verify_client_cert(
124 &self,
125 end_entity: &CertificateDer<'_>,
126 intermediates: &[CertificateDer<'_>],
127 now: rustls::pki_types::UnixTime,
128 ) -> std::result::Result<ClientCertVerified, rustls::Error> {
129 self.inner
131 .verify_client_cert(end_entity, intermediates, now)?;
132
133 let fingerprint = cert_sha256_fingerprint(end_entity.as_ref());
135 if self.pinned_fingerprints.contains(&fingerprint) {
136 Ok(ClientCertVerified::assertion())
137 } else {
138 Err(rustls::Error::General(format!(
139 "client certificate fingerprint {fingerprint} is not in the pinned set"
140 )))
141 }
142 }
143
144 fn verify_tls12_signature(
145 &self,
146 message: &[u8],
147 cert: &CertificateDer<'_>,
148 dss: &rustls::DigitallySignedStruct,
149 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
150 self.inner.verify_tls12_signature(message, cert, dss)
151 }
152
153 fn verify_tls13_signature(
154 &self,
155 message: &[u8],
156 cert: &CertificateDer<'_>,
157 dss: &rustls::DigitallySignedStruct,
158 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
159 self.inner.verify_tls13_signature(message, cert, dss)
160 }
161
162 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
163 self.inner.supported_verify_schemes()
164 }
165}
166
167fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
169 let file = std::fs::File::open(path)
170 .with_context(|| format!("cannot open certificate file: {path}"))?;
171 let mut reader = std::io::BufReader::new(file);
172 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
173 .collect::<std::result::Result<Vec<_>, _>>()
174 .with_context(|| format!("failed to parse PEM certificates from {path}"))?;
175 if certs.is_empty() {
176 anyhow::bail!("no certificates found in {path}");
177 }
178 Ok(certs)
179}
180
181fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
183 let file = std::fs::File::open(path)
184 .with_context(|| format!("cannot open private key file: {path}"))?;
185 let mut reader = std::io::BufReader::new(file);
186 let key = rustls_pemfile::private_key(&mut reader)
187 .with_context(|| format!("failed to parse private key from {path}"))?
188 .ok_or_else(|| {
189 ::zeroclaw_log::record!(
190 ERROR,
191 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
192 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
193 .with_attrs(::serde_json::json!({"path": path})),
194 "TLS private key file contains no key"
195 );
196 anyhow::Error::msg(format!("no private key found in {path}"))
197 })?;
198 Ok(key)
199}
200
/// Unit tests for certificate/key loading, fingerprinting, config defaults, and
/// server-config construction with and without client authentication.
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs the `ring` crypto provider as the process default so that
    /// `rustls::ServerConfig::builder()` can be used in tests.
    fn ensure_crypto_provider() {
        // Result deliberately ignored: tests share one process, and installation
        // presumably fails once a default provider is already set.
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Generates a fresh self-signed CA.
    ///
    /// Returns `(ca_cert_pem, ca_key_pem, ca_key)`.
    /// NOTE(review): in rcgen the `CertificateParams::new` argument is believed to
    /// be the subject-alt-name list rather than the subject DN — confirm.
    fn test_ca() -> (String, String, rcgen::KeyPair) {
        let ca_key = rcgen::KeyPair::generate().unwrap();
        let mut ca_params = rcgen::CertificateParams::new(vec!["Test CA".into()]).unwrap();
        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let ca_cert = ca_params.self_signed(&ca_key).unwrap();
        (ca_cert.pem(), ca_key.serialize_pem(), ca_key)
    }

    /// Generates a `localhost` leaf certificate signed with `ca_key`.
    ///
    /// Returns `(server_cert_pem, server_key_pem)`.
    /// NOTE(review): `ca_cert_pem` is unused. The issuer certificate is
    /// re-synthesized from the same key and parameters instead of parsed from
    /// the PEM, so it is not byte-identical to the cert from `test_ca`. Chain
    /// checks presumably still succeed via matching issuer name and key.
    fn test_server_cert(ca_cert_pem: &str, ca_key: &rcgen::KeyPair) -> (String, String) {
        // rcgen signing takes the key by reference; round-trip through PEM to
        // obtain an owned copy of the CA key pair.
        let ca_key_clone = rcgen::KeyPair::from_pem(&ca_key.serialize_pem()).unwrap();
        let mut ca_params = rcgen::CertificateParams::new(vec!["Test CA".into()]).unwrap();
        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let ca = ca_params.self_signed(&ca_key_clone).unwrap();

        let mut server_params = rcgen::CertificateParams::new(vec!["localhost".into()]).unwrap();
        server_params.is_ca = rcgen::IsCa::NoCa;
        let server_key = rcgen::KeyPair::generate().unwrap();
        let server_cert = server_params
            .signed_by(&server_key, &ca, &ca_key_clone)
            .unwrap();
        // Silences the unused-parameter warning (see NOTE above).
        let _ = ca_cert_pem;
        (server_cert.pem(), server_key.serialize_pem())
    }

    /// Writes `content` to a new named temp file and returns its handle.
    ///
    /// Callers keep the handle alive for the whole test; `NamedTempFile` is
    /// expected to delete the file when dropped.
    fn write_temp_file(content: &str) -> tempfile::NamedTempFile {
        use std::io::Write;
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(content.as_bytes()).unwrap();
        // Flush so a reader that reopens the file by path sees the full content.
        f.flush().unwrap();
        f
    }

    /// A CA-signed leaf certificate and its key both load successfully.
    #[test]
    fn test_load_valid_cert_and_key() {
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);

        let certs = load_certs(cert_file.path().to_str().unwrap()).unwrap();
        assert!(!certs.is_empty());

        let _key = load_private_key(key_file.path().to_str().unwrap()).unwrap();
    }

    /// A missing certificate file yields the "cannot open" context message.
    #[test]
    fn test_invalid_cert_path_produces_clear_error() {
        let err = load_certs("/nonexistent/path/cert.pem").unwrap_err();
        // `{:#}` renders the full anyhow context chain.
        let msg = format!("{err:#}");
        assert!(
            msg.contains("cannot open certificate file"),
            "unexpected error: {msg}"
        );
    }

    /// A missing key file yields the "cannot open" context message.
    #[test]
    fn test_invalid_key_path_produces_clear_error() {
        let err = load_private_key("/nonexistent/path/key.pem").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("cannot open private key file"),
            "unexpected error: {msg}"
        );
    }

    /// Server config builds when no client auth is configured.
    #[test]
    fn test_build_server_config_no_client_auth() {
        ensure_crypto_provider();
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);

        let tls_config = GatewayTlsConfig {
            enabled: true,
            cert_path: cert_file.path().to_str().unwrap().to_string(),
            key_path: key_file.path().to_str().unwrap().to_string(),
            client_auth: None,
        };

        let _server_config = build_server_config(&tls_config).unwrap();
    }

    /// Server config builds with mandatory client auth against the test CA.
    #[test]
    fn test_build_server_config_with_client_auth() {
        ensure_crypto_provider();
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);
        let ca_file = write_temp_file(&ca_cert_pem);

        let tls_config = GatewayTlsConfig {
            enabled: true,
            cert_path: cert_file.path().to_str().unwrap().to_string(),
            key_path: key_file.path().to_str().unwrap().to_string(),
            client_auth: Some(GatewayClientAuthConfig {
                enabled: true,
                ca_cert_path: ca_file.path().to_str().unwrap().to_string(),
                require_client_cert: true,
                pinned_certs: vec![],
            }),
        };

        let _server_config = build_server_config(&tls_config).unwrap();
    }

    /// Server config builds with optional (`require_client_cert = false`) client auth.
    #[test]
    fn test_build_server_config_client_auth_optional() {
        ensure_crypto_provider();
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);
        let ca_file = write_temp_file(&ca_cert_pem);

        let tls_config = GatewayTlsConfig {
            enabled: true,
            cert_path: cert_file.path().to_str().unwrap().to_string(),
            key_path: key_file.path().to_str().unwrap().to_string(),
            client_auth: Some(GatewayClientAuthConfig {
                enabled: true,
                ca_cert_path: ca_file.path().to_str().unwrap().to_string(),
                require_client_cert: false,
                pinned_certs: vec![],
            }),
        };

        let _server_config = build_server_config(&tls_config).unwrap();
    }

    /// Fingerprints are 64 lowercase-or-digit hex chars and deterministic.
    #[test]
    fn test_cert_fingerprint_matching() {
        let (ca_cert_pem, _ca_key_pem, _ca_key) = test_ca();
        let ca_file = write_temp_file(&ca_cert_pem);
        let certs = load_certs(ca_file.path().to_str().unwrap()).unwrap();
        let fingerprint = cert_sha256_fingerprint(certs[0].as_ref());

        // SHA-256 is 32 bytes -> 64 hex characters.
        assert_eq!(fingerprint.len(), 64);
        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));

        // Same input must always give the same fingerprint.
        let fingerprint2 = cert_sha256_fingerprint(certs[0].as_ref());
        assert_eq!(fingerprint, fingerprint2);
    }

    /// Two independently generated CAs have different fingerprints.
    #[test]
    fn test_fingerprint_differs_for_different_certs() {
        let (ca_cert_pem1, _, _) = test_ca();
        let (ca_cert_pem2, _, _) = test_ca();
        let f1 = write_temp_file(&ca_cert_pem1);
        let f2 = write_temp_file(&ca_cert_pem2);
        let certs1 = load_certs(f1.path().to_str().unwrap()).unwrap();
        let certs2 = load_certs(f2.path().to_str().unwrap()).unwrap();
        let fp1 = cert_sha256_fingerprint(certs1[0].as_ref());
        let fp2 = cert_sha256_fingerprint(certs2[0].as_ref());
        assert_ne!(fp1, fp2);
    }

    /// Omitted TOML fields default to TLS disabled and no client auth.
    #[test]
    fn test_config_defaults_deserialization() {
        let toml_str = r#"
            cert_path = "/tmp/cert.pem"
            key_path = "/tmp/key.pem"
        "#;
        let config: GatewayTlsConfig = toml::from_str(toml_str).unwrap();
        assert!(!config.enabled);
        assert!(config.client_auth.is_none());
    }

    /// Client-auth defaults: disabled, certificate required when enabled, no pins.
    #[test]
    fn test_client_auth_config_defaults() {
        let toml_str = r#"
            ca_cert_path = "/tmp/ca.pem"
        "#;
        let config: GatewayClientAuthConfig = toml::from_str(toml_str).unwrap();
        assert!(!config.enabled);
        assert!(config.require_client_cert);
        assert!(config.pinned_certs.is_empty());
    }

    /// Server config builds when pinned fingerprints are configured.
    ///
    /// NOTE(review): this only exercises construction — "aabbccdd" is not a real
    /// fingerprint and `verify_client_cert` is never invoked, so acceptance or
    /// rejection of pinned certificates is not covered here.
    #[test]
    fn test_build_server_config_with_pinning() {
        ensure_crypto_provider();
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);
        let ca_file = write_temp_file(&ca_cert_pem);

        let tls_config = GatewayTlsConfig {
            enabled: true,
            cert_path: cert_file.path().to_str().unwrap().to_string(),
            key_path: key_file.path().to_str().unwrap().to_string(),
            client_auth: Some(GatewayClientAuthConfig {
                enabled: true,
                ca_cert_path: ca_file.path().to_str().unwrap().to_string(),
                require_client_cert: true,
                pinned_certs: vec!["aabbccdd".to_string()],
            }),
        };

        let _server_config = build_server_config(&tls_config).unwrap();
    }

    /// A PEM file with no certificates is rejected with a clear message.
    #[test]
    fn test_empty_cert_file_produces_error() {
        let empty_file = write_temp_file("");
        let err = load_certs(empty_file.path().to_str().unwrap()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("no certificates found"),
            "unexpected error: {msg}"
        );
    }

    /// With `enabled: false` the client-auth section is ignored entirely: the
    /// nonexistent CA path would fail if it were ever loaded.
    #[test]
    fn test_disabled_client_auth_skipped() {
        ensure_crypto_provider();
        let (ca_cert_pem, _ca_key_pem, ca_key) = test_ca();
        let (server_cert_pem, server_key_pem) = test_server_cert(&ca_cert_pem, &ca_key);

        let cert_file = write_temp_file(&server_cert_pem);
        let key_file = write_temp_file(&server_key_pem);

        let tls_config = GatewayTlsConfig {
            enabled: true,
            cert_path: cert_file.path().to_str().unwrap().to_string(),
            key_path: key_file.path().to_str().unwrap().to_string(),
            client_auth: Some(GatewayClientAuthConfig {
                enabled: false,
                ca_cert_path: "/nonexistent".to_string(),
                require_client_cert: true,
                pinned_certs: vec![],
            }),
        };

        let _server_config = build_server_config(&tls_config).unwrap();
    }
}