1use crate::traits::{
11 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
12 ModelProvider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
13};
14use async_trait::async_trait;
15use hmac::{Hmac, Mac};
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use std::sync::Mutex;
20use zeroclaw_api::tool::ToolSpec;
21
/// Host-name prefix of the Bedrock runtime endpoint (`<prefix>.<region>.amazonaws.com`).
const ENDPOINT_PREFIX: &str = "bedrock-runtime";
/// Service name used in the SigV4 credential scope and in signing-key derivation.
const SIGNING_SERVICE: &str = "bedrock";
/// Fallback region used when no region can be resolved from the environment or AWS config.
const DEFAULT_REGION: &str = "us-east-1";
27
/// Authentication material attached to a Bedrock request.
enum BedrockAuth {
    /// AWS credentials from which a SigV4 `Authorization` header is built.
    SigV4(AwsCredentials),
    /// Opaque bearer token (taken from `BEDROCK_API_KEY` or supplied by the caller).
    BearerToken(String),
}
35
/// AWS credentials plus the region they will be used in.
#[derive(Clone)]
struct AwsCredentials {
    /// Access key id.
    access_key_id: String,
    /// Secret access key; used to derive the SigV4 signing key.
    secret_access_key: String,
    /// Session token, present for temporary credentials.
    session_token: Option<String>,
    /// Region used for the endpoint host and the credential scope.
    region: String,
    /// Expiry instant, when the credential source reported one; `None` is treated as
    /// "never expires" by `is_expired`.
    expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
49
50impl AwsCredentials {
51 fn from_env() -> anyhow::Result<Self> {
53 let access_key_id = env_required("AWS_ACCESS_KEY_ID")?;
54 let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?;
55
56 let session_token = env_optional("AWS_SESSION_TOKEN");
57
58 let region = env_optional("AWS_REGION")
59 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
60 .unwrap_or_else(|| DEFAULT_REGION.to_string());
61
62 Ok(Self {
63 access_key_id,
64 secret_access_key,
65 session_token,
66 region,
67 expires_at: None,
68 })
69 }
70
/// Extracts the `credential_process` command (and optional `region`) configured for
/// `profile` in the text of an AWS shared config file.
///
/// The default profile is recognised under both accepted spellings, `[default]` and
/// `[profile default]`; named profiles use `[profile <name>]`. Whitespace inside the
/// section header is tolerated. Lines starting with `#` or `;` are ignored, and
/// `key = value` pairs are split on the first `=` so values may themselves contain `=`.
///
/// Returns `None` when the profile has no `credential_process` entry. The region is
/// only returned alongside a command.
fn parse_aws_config(content: &str, profile: &str) -> Option<(String, Option<String>)> {
    /// Returns the profile name declared by a `[...]` header line, if it declares one.
    fn header_profile(line: &str) -> Option<&str> {
        let inner = line.strip_prefix('[')?.strip_suffix(']')?;
        let mut words = inner.split_whitespace();
        match (words.next(), words.next(), words.next()) {
            (Some("default"), None, None) => Some("default"),
            (Some("profile"), Some(name), None) => Some(name),
            _ => None,
        }
    }

    // Exact header kept as a first-chance match so previously accepted files still match.
    let exact_header = if profile == "default" {
        "[default]".to_string()
    } else {
        format!("[profile {profile}]")
    };

    let mut in_section = false;
    let mut cred_process = None;
    let mut region = None;

    for line in content.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with('[') {
            in_section = trimmed == exact_header || header_profile(trimmed) == Some(profile);
            continue;
        }
        if !in_section || trimmed.starts_with('#') || trimmed.starts_with(';') {
            continue;
        }
        if let Some((key, value)) = trimmed.split_once('=') {
            match key.trim() {
                "credential_process" => cred_process = Some(value.trim().to_string()),
                "region" => region = Some(value.trim().to_string()),
                _ => {}
            }
        }
    }
    cred_process.map(|cmd| (cmd, region))
}
103
/// Loads credentials by running the `credential_process` command of the active AWS profile.
///
/// The config file path comes from `AWS_CONFIG_FILE` (default `$HOME/.aws/config`) and the
/// profile name from `AWS_PROFILE` (default `default`). The command is executed through
/// `sh -c`, and its stdout is parsed as JSON carrying `AccessKeyId`, `SecretAccessKey`
/// and, optionally, `SessionToken` and `Expiration`.
///
/// The region is resolved from `AWS_REGION`, then `AWS_DEFAULT_REGION`, then the profile's
/// `region` key, then `DEFAULT_REGION`.
///
/// # Errors
/// Fails (after logging an ERROR record) when the config file cannot be read, the profile
/// has no `credential_process`, the command cannot be spawned, or its output is not JSON
/// or lacks a required key. A non-zero exit status also fails, without a log record.
fn from_credential_process() -> anyhow::Result<Self> {
    let config_path = std::env::var("AWS_CONFIG_FILE").unwrap_or_else(|_| {
        // With HOME unset the path starts with a literal "~", which is not expanded here.
        let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
        format!("{home}/.aws/config")
    });
    let content = std::fs::read_to_string(&config_path).map_err(|e| {
        ::zeroclaw_log::record!(
            ERROR,
            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                .with_attrs(::serde_json::json!({
                    "config_path": &config_path,
                    "error": format!("{}", e),
                })),
            "bedrock: cannot read AWS config file"
        );
        anyhow::Error::msg(format!("Cannot read {config_path}: {e}"))
    })?;
    let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
    let (cmd, config_region) = Self::parse_aws_config(&content, &profile).ok_or_else(|| {
        ::zeroclaw_log::record!(
            ERROR,
            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                .with_attrs(::serde_json::json!({"profile": &profile})),
            "bedrock: no credential_process in AWS profile"
        );
        anyhow::Error::msg(format!("No credential_process in [{profile}]"))
    })?;

    // Blocking call: the command runs to completion before this function returns.
    let output = std::process::Command::new("sh")
        .args(["-c", &cmd])
        .output()
        .map_err(|e| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
                "bedrock: failed to spawn credential_process"
            );
            anyhow::Error::msg(format!("Failed to run credential_process: {e}"))
        })?;
    anyhow::ensure!(
        output.status.success(),
        "credential_process exited with {}: {}",
        output.status,
        String::from_utf8_lossy(&output.stderr).trim()
    );

    let json: serde_json::Value = serde_json::from_slice(&output.stdout).map_err(|e| {
        ::zeroclaw_log::record!(
            ERROR,
            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
            "bedrock: credential_process output is not valid JSON"
        );
        anyhow::Error::msg(format!("credential_process output is not valid JSON: {e}"))
    })?;

    let access_key_id = json["AccessKeyId"]
        .as_str()
        .ok_or_else(|| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({"missing": "AccessKeyId"})),
                "bedrock: credential_process missing AccessKeyId"
            );
            anyhow::Error::msg("Missing AccessKeyId in credential_process output")
        })?
        .to_string();
    let secret_access_key = json["SecretAccessKey"]
        .as_str()
        .ok_or_else(|| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({"missing": "SecretAccessKey"})),
                "bedrock: credential_process missing SecretAccessKey"
            );
            anyhow::Error::msg("Missing SecretAccessKey in credential_process output")
        })?
        .to_string();
    let session_token = json["SessionToken"].as_str().map(|s| s.to_string());

    // An absent or unparseable `Expiration` yields `None`, i.e. "never expires".
    let expires_at = json["Expiration"]
        .as_str()
        .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
        .map(|dt| dt.with_timezone(&chrono::Utc));

    // Environment variables take precedence over the profile's own `region` key.
    let region = env_optional("AWS_REGION")
        .or_else(|| env_optional("AWS_DEFAULT_REGION"))
        .or(config_region)
        .unwrap_or_else(|| DEFAULT_REGION.to_string());

    ::zeroclaw_log::record!(
        DEBUG,
        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
        "Loaded AWS credentials via credential_process"
    );

    Ok(Self {
        access_key_id,
        secret_access_key,
        session_token,
        region,
        expires_at,
    })
}
218
219 async fn from_imds() -> anyhow::Result<Self> {
221 let client = reqwest::Client::builder()
222 .timeout(std::time::Duration::from_secs(3))
223 .build()?;
224
225 let token = client
227 .put("http://169.254.169.254/latest/api/token")
228 .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
229 .send()
230 .await?
231 .text()
232 .await?;
233
234 let role = client
236 .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
237 .header("X-aws-ec2-metadata-token", &token)
238 .send()
239 .await?
240 .text()
241 .await?;
242 let role = role.trim().to_string();
243 anyhow::ensure!(!role.is_empty(), "No IAM role attached to this instance");
244
245 let creds_url = format!(
247 "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
248 role
249 );
250 let creds_json: serde_json::Value = client
251 .get(&creds_url)
252 .header("X-aws-ec2-metadata-token", &token)
253 .send()
254 .await?
255 .json()
256 .await?;
257
258 let access_key_id = creds_json["AccessKeyId"]
259 .as_str()
260 .ok_or_else(|| {
261 ::zeroclaw_log::record!(
262 ERROR,
263 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
264 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
265 .with_attrs(::serde_json::json!({
266 "source": "imds",
267 "missing": "AccessKeyId",
268 })),
269 "bedrock: IMDS response missing AccessKeyId"
270 );
271 anyhow::Error::msg("Missing AccessKeyId in IMDS response")
272 })?
273 .to_string();
274 let secret_access_key = creds_json["SecretAccessKey"]
275 .as_str()
276 .ok_or_else(|| {
277 ::zeroclaw_log::record!(
278 ERROR,
279 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
280 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
281 .with_attrs(::serde_json::json!({
282 "source": "imds",
283 "missing": "SecretAccessKey",
284 })),
285 "bedrock: IMDS response missing SecretAccessKey"
286 );
287 anyhow::Error::msg("Missing SecretAccessKey in IMDS response")
288 })?
289 .to_string();
290 let session_token = creds_json["Token"].as_str().map(|s| s.to_string());
291
292 let region = match client
294 .get("http://169.254.169.254/latest/meta-data/placement/region")
295 .header("X-aws-ec2-metadata-token", &token)
296 .send()
297 .await
298 {
299 Ok(resp) => resp.text().await.unwrap_or_default(),
300 Err(_) => String::new(),
301 };
302 let region = if region.trim().is_empty() {
303 env_optional("AWS_REGION")
304 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
305 .unwrap_or_else(|| DEFAULT_REGION.to_string())
306 } else {
307 region.trim().to_string()
308 };
309
310 ::zeroclaw_log::record!(
311 INFO,
312 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
313 &format!(
314 "Loaded AWS credentials from EC2 instance metadata (role: {})",
315 role
316 )
317 );
318
319 Ok(Self {
320 access_key_id,
321 secret_access_key,
322 session_token,
323 region,
324 expires_at: None,
325 })
326 }
327
328 async fn resolve() -> anyhow::Result<Self> {
330 if let Ok(creds) = Self::from_env() {
331 return Ok(creds);
332 }
333 if let Ok(creds) = Self::from_credential_process() {
334 return Ok(creds);
335 }
336 Self::from_imds().await
337 }
338
339 fn host(&self) -> String {
340 format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region)
341 }
342
343 fn is_expired(&self) -> bool {
346 match self.expires_at {
347 Some(exp) => chrono::Utc::now() >= exp - chrono::Duration::seconds(60),
348 None => false,
349 }
350 }
351}
352
353fn env_required(name: &str) -> anyhow::Result<String> {
354 std::env::var(name)
355 .ok()
356 .map(|v| v.trim().to_string())
357 .filter(|v| !v.is_empty())
358 .ok_or_else(|| {
359 ::zeroclaw_log::record!(
360 ERROR,
361 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
362 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
363 .with_attrs(::serde_json::json!({"env_var": name})),
364 "bedrock: required environment variable is missing"
365 );
366 anyhow::Error::msg(format!(
367 "Environment variable {name} is required for Bedrock"
368 ))
369 })
370}
371
/// Reads an optional environment variable, trimmed of surrounding whitespace.
///
/// Returns `None` when the variable is unset, not valid Unicode, or blank after trimming.
fn env_optional(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let value = raw.trim();
    if value.is_empty() {
        None
    } else {
        Some(value.to_string())
    }
}
378
379fn sha256_hex(data: &[u8]) -> String {
382 let mut hasher = Sha256::new();
383 hasher.update(data);
384 hex::encode(hasher.finalize())
385}
386
387fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
388 let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
389 mac.update(data);
390 mac.finalize().into_bytes().to_vec()
391}
392
393fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
395 let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
396 let k_region = hmac_sha256(&k_date, region.as_bytes());
397 let k_service = hmac_sha256(&k_region, service.as_bytes());
398 hmac_sha256(&k_service, b"aws4_request")
399}
400
401fn build_authorization_header(
405 credentials: &AwsCredentials,
406 method: &str,
407 canonical_uri: &str,
408 query_string: &str,
409 headers: &[(String, String)],
410 payload: &[u8],
411 timestamp: &chrono::DateTime<chrono::Utc>,
412) -> String {
413 let date_stamp = timestamp.format("%Y%m%d").to_string();
414 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
415
416 let mut canonical_headers = String::new();
417 for (k, v) in headers {
418 canonical_headers.push_str(k);
419 canonical_headers.push(':');
420 canonical_headers.push_str(v);
421 canonical_headers.push('\n');
422 }
423
424 let signed_headers: String = headers
425 .iter()
426 .map(|(k, _)| k.as_str())
427 .collect::<Vec<_>>()
428 .join(";");
429
430 let payload_hash = sha256_hex(payload);
431
432 let canonical_request = format!(
433 "{method}\n{canonical_uri}\n{query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
434 );
435
436 let credential_scope = format!(
437 "{date_stamp}/{}/{SIGNING_SERVICE}/aws4_request",
438 credentials.region
439 );
440
441 let string_to_sign = format!(
442 "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
443 sha256_hex(canonical_request.as_bytes())
444 );
445
446 let signing_key = derive_signing_key(
447 &credentials.secret_access_key,
448 &date_stamp,
449 &credentials.region,
450 SIGNING_SERVICE,
451 );
452
453 let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()));
454
455 format!(
456 "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
457 credentials.access_key_id
458 )
459}
460
/// Request body for the Converse call. Keys are serialized in camelCase and every
/// optional field is omitted from the JSON when it is `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ConverseRequest {
    /// Conversation turns, in order.
    messages: Vec<ConverseMessage>,
    /// System prompt blocks.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<Vec<SystemBlock>>,
    /// Token limit and temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    inference_config: Option<InferenceConfig>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<ToolConfig>,
    /// Free-form JSON forwarded as `additionalModelRequestFields`.
    #[serde(skip_serializing_if = "Option::is_none")]
    additional_model_request_fields: Option<serde_json::Value>,
}
476
/// One conversation turn: a role plus its ordered content blocks.
#[derive(Debug, Serialize, Deserialize)]
struct ConverseMessage {
    /// Speaker role; this file produces `"user"` and `"assistant"`.
    role: String,
    /// Content blocks making up the turn.
    content: Vec<ContentBlock>,
}
482
/// A single content block inside a message.
///
/// The enum is untagged: each variant serializes as its inner wrapper with no enum tag,
/// and deserialization tries the variants in declaration order.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ContentBlock {
    /// Plain text (`{"text": ...}`).
    Text(TextBlock),
    /// Tool invocation requested by the assistant (`{"toolUse": ...}`).
    ToolUse(ToolUseWrapper),
    /// Result of a tool invocation (`{"toolResult": ...}`).
    ToolResult(ToolResultWrapper),
    /// Prompt-cache marker (`{"cachePoint": ...}`).
    CachePointBlock(CachePointWrapper),
    /// Inline image (`{"image": ...}`).
    Image(ImageWrapper),
    /// Reasoning block echoed back to the model (`{"reasoningContent": ...}`).
    // NOTE(review): a variant-level `rename` on an untagged enum looks like it has no
    // effect on the wire format (the key comes from the wrapper struct) — confirm.
    #[serde(rename = "reasoningContent")]
    ReasoningContent(ReasoningContentOutWrapper),
}
503
/// Wrapper providing the `reasoningContent` key for an outgoing reasoning block.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentOutWrapper {
    /// Serialized as `reasoningContent`.
    reasoning_content: ReasoningContentOutBlock,
}
511
/// Body of an outgoing reasoning block.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentOutBlock {
    /// Serialized as `reasoningText`.
    reasoning_text: ReasoningTextOutField,
}
517
/// Reasoning text with its optional signature, as sent back to the model.
#[derive(Debug, Serialize, Deserialize)]
struct ReasoningTextOutField {
    /// The reasoning text.
    text: String,
    /// Signature accompanying the text; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    signature: Option<String>,
}
526
/// Wrapper providing the `image` key for an image content block.
#[derive(Debug, Serialize, Deserialize)]
struct ImageWrapper {
    /// The image payload.
    image: ImageBlock,
}
531
/// Image payload: a format name plus the image source.
#[derive(Debug, Serialize, Deserialize)]
struct ImageBlock {
    /// Format name; this file emits `png`, `gif`, `webp` or `jpeg`.
    format: String,
    /// Where the image data comes from.
    source: ImageSource,
}
537
/// Inline image data.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ImageSource {
    /// Image data as text; populated in this file with the base64 part of a data URL.
    bytes: String,
}
543
/// Plain text block (`{"text": ...}`), shared by requests and responses.
#[derive(Debug, Serialize, Deserialize)]
struct TextBlock {
    /// The text content.
    text: String,
}
548
/// Wrapper providing the `toolUse` key for a tool-invocation block.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseWrapper {
    /// Serialized as `toolUse`.
    tool_use: ToolUseBlock,
}
554
/// A tool invocation: id, tool name and JSON arguments.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseBlock {
    /// Identifier that the matching tool result refers back to (`toolUseId`).
    tool_use_id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Tool arguments as arbitrary JSON.
    input: serde_json::Value,
}
562
/// Wrapper providing the `toolResult` key for a tool-result block.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultWrapper {
    /// Serialized as `toolResult`.
    tool_result: ToolResultBlock,
}
568
/// Result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultBlock {
    /// Id of the tool invocation this result answers (`toolUseId`).
    tool_use_id: String,
    /// Result payload blocks.
    content: Vec<ToolResultContent>,
    /// Outcome marker; this file writes `"success"` or `"error"`.
    status: String,
}
576
/// Wrapper providing the `cachePoint` key for a cache marker.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CachePointWrapper {
    /// Serialized as `cachePoint`.
    cache_point: CachePoint,
}
582
/// Text payload of a tool result.
#[derive(Debug, Serialize, Deserialize)]
struct ToolResultContent {
    /// The tool's output as text.
    text: String,
}
587
/// Cache marker body.
#[derive(Debug, Serialize, Deserialize)]
struct CachePoint {
    /// Cache kind, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    cache_type: String,
}
593
594impl CachePoint {
595 fn default_cache() -> Self {
596 Self {
597 cache_type: "default".to_string(),
598 }
599 }
600}
601
/// One entry of the request's `system` array. Untagged: serialized as the inner value.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum SystemBlock {
    /// System prompt text.
    Text(TextBlock),
    /// Prompt-cache marker.
    CachePoint(CachePointWrapper),
}
609
/// Inference parameters, serialized with camelCase keys.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InferenceConfig {
    /// Maximum number of tokens to generate (`maxTokens`).
    max_tokens: u32,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
}
617
/// Returns `false` only for model ids containing `claude-opus-4-7` (case-sensitive
/// match); every other model id is reported as supporting native thinking.
fn bedrock_model_supports_native_thinking(model: &str) -> bool {
    const UNSUPPORTED_MARKER: &str = "claude-opus-4-7";
    model.find(UNSUPPORTED_MARKER).is_none()
}
628
/// Returns `true` when the model id contains `claude` or `nova`, ignoring ASCII case.
fn bedrock_model_supports_prompt_caching(model: &str) -> bool {
    let lowered = model.to_ascii_lowercase();
    ["claude", "nova"]
        .iter()
        .any(|family| lowered.contains(family))
}
642
/// Tool configuration attached to a request (`toolConfig`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolConfig {
    /// Tools the model may call.
    tools: Vec<ToolDefinition>,
}
648
/// Wrapper providing the `toolSpec` key for one tool definition.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolDefinition {
    /// Serialized as `toolSpec`.
    tool_spec: ToolSpecDef,
}
654
/// Specification of a single tool: name, description and input schema.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolSpecDef {
    /// Tool name.
    name: String,
    /// Human-readable description of the tool.
    description: String,
    /// Schema of the tool's arguments (`inputSchema`).
    input_schema: InputSchema,
}
662
/// Wrapper providing the `json` key around a tool's parameter schema.
#[derive(Debug, Serialize)]
struct InputSchema {
    /// The schema as arbitrary JSON.
    json: serde_json::Value,
}
667
/// Response body of the Converse call. Keys are read in camelCase and every field
/// defaults to `None` when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    /// Model output, containing the assistant message.
    #[serde(default)]
    output: Option<ConverseOutput>,
    /// Stop reason reported by the service (`stopReason`); deserialized but marked
    /// `dead_code` because nothing reads it.
    #[serde(default)]
    #[allow(dead_code)]
    stop_reason: Option<String>,
    /// Token accounting.
    #[serde(default)]
    usage: Option<BedrockUsage>,
}
681
/// Token usage reported in a response; both counters are optional.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockUsage {
    /// Input token count (`inputTokens`).
    #[serde(default)]
    input_tokens: Option<u64>,
    /// Output token count (`outputTokens`).
    #[serde(default)]
    output_tokens: Option<u64>,
}
690
/// The `output` object of a response.
#[derive(Debug, Deserialize)]
struct ConverseOutput {
    /// The assistant message, when present.
    #[serde(default)]
    message: Option<ConverseOutputMessage>,
}
696
/// Assistant message returned by the service.
#[derive(Debug, Deserialize)]
struct ConverseOutputMessage {
    /// Role of the message; required by deserialization but not read afterwards.
    #[allow(dead_code)]
    role: String,
    /// Content blocks of the reply.
    content: Vec<ResponseContentBlock>,
}
703
/// One content block of a response message.
///
/// Untagged: variants are tried in declaration order, so the more specific wrappers come
/// first and `Other` absorbs any block that matches none of them.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ResponseContentBlock {
    /// Tool invocation requested by the model.
    ToolUse(ResponseToolUseWrapper),
    /// Reasoning block.
    ReasoningContent(ReasoningContentWrapper),
    /// Plain text.
    Text(TextBlock),
    /// Any unrecognised block, kept as raw JSON and otherwise ignored.
    Other(#[allow(dead_code)] serde_json::Value),
}
718
/// Wrapper matching the `reasoningContent` key of a response block.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentWrapper {
    /// Read from `reasoningContent`.
    reasoning_content: ReasoningContentBlock,
}
724
/// Body of a response reasoning block.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentBlock {
    /// Read from `reasoningText`; `None` when the key is absent.
    #[serde(default)]
    reasoning_text: Option<ReasoningTextField>,
}
731
/// Reasoning text and signature as received; both are optional.
#[derive(Debug, Deserialize)]
struct ReasoningTextField {
    /// The reasoning text.
    #[serde(default)]
    text: Option<String>,
    /// Signature accompanying the text.
    #[serde(default)]
    signature: Option<String>,
}
741
/// Wrapper matching the `toolUse` key of a response block.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ResponseToolUseWrapper {
    /// Read from `toolUse`.
    tool_use: ToolUseBlock,
}
747
/// Model provider backed by the AWS Bedrock Converse API.
pub struct BedrockModelProvider {
    /// Alias given at construction time.
    // NOTE(review): how the alias is consumed is not visible in this part of the file.
    alias: String,
    /// Authentication resolved at construction time, if any source succeeded.
    auth: Option<BedrockAuth>,
    /// Token limit; initialised to `BASELINE_MAX_TOKENS` and overridable via `with_max_tokens`.
    max_tokens: u32,
    /// Credentials obtained from `credential_process`, reused until they near expiry.
    cred_cache: Mutex<Option<AwsCredentials>>,
}
758
759impl BedrockModelProvider {
760 pub fn new(alias: &str) -> Self {
761 if let Some(token) = env_optional("BEDROCK_API_KEY") {
763 return Self {
764 alias: alias.to_string(),
765 auth: Some(BedrockAuth::BearerToken(token)),
766 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
767 cred_cache: Mutex::new(None),
768 };
769 }
770 Self {
771 alias: alias.to_string(),
772 auth: AwsCredentials::from_env()
773 .or_else(|_| AwsCredentials::from_credential_process())
774 .ok()
775 .map(BedrockAuth::SigV4),
776 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
777 cred_cache: Mutex::new(None),
778 }
779 }
780
781 pub async fn new_async(alias: &str) -> Self {
782 if let Some(token) = env_optional("BEDROCK_API_KEY") {
784 return Self {
785 alias: alias.to_string(),
786 auth: Some(BedrockAuth::BearerToken(token)),
787 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
788 cred_cache: Mutex::new(None),
789 };
790 }
791 let auth = AwsCredentials::resolve().await.ok().map(BedrockAuth::SigV4);
792 Self {
793 alias: alias.to_string(),
794 auth,
795 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
796 cred_cache: Mutex::new(None),
797 }
798 }
799
800 pub fn with_bearer_token(alias: &str, token: &str) -> Self {
802 Self {
803 alias: alias.to_string(),
804 auth: Some(BedrockAuth::BearerToken(token.to_string())),
805 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
806 cred_cache: Mutex::new(None),
807 }
808 }
809 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
811 self.max_tokens = max_tokens;
812 self
813 }
814
815 fn http_client(&self) -> Client {
816 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
817 "model_provider.bedrock",
818 120,
819 10,
820 )
821 }
822
/// Percent-encodes every `:` in a model id as `%3A`; all other characters are kept.
fn encode_model_path(model_id: &str) -> String {
    model_id.split(':').collect::<Vec<_>>().join("%3A")
}
829
830 fn resolve_region() -> String {
832 env_optional("AWS_REGION")
833 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
834 .unwrap_or_else(|| DEFAULT_REGION.to_string())
835 }
836
837 fn endpoint_url(region: &str, model_id: &str) -> String {
839 format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
840 }
841
842 fn canonical_uri(model_id: &str) -> String {
846 let encoded = Self::encode_model_path(model_id);
847 format!("/model/{encoded}/converse")
848 }
849
850 fn cached_credentials(&self) -> Option<AwsCredentials> {
852 let cache = self.cred_cache.lock().ok()?;
853 let creds = cache.as_ref()?;
854 if creds.is_expired() {
855 return None;
856 }
857 Some(creds.clone())
858 }
859
860 fn cache_credentials(&self, creds: &AwsCredentials) {
862 if let Ok(mut cache) = self.cred_cache.lock() {
863 *cache = Some(creds.clone());
864 }
865 }
866
867 async fn resolve_auth(&self) -> anyhow::Result<BedrockAuth> {
869 if let Some(ref auth) = self.auth {
871 match auth {
872 BedrockAuth::BearerToken(token) => {
873 return Ok(BedrockAuth::BearerToken(token.clone()));
874 }
875 BedrockAuth::SigV4(_) => {
876 if let Some(creds) = self.cached_credentials() {
877 return Ok(BedrockAuth::SigV4(creds));
878 }
879 }
880 }
881 }
882 if let Some(token) = env_optional("BEDROCK_API_KEY") {
884 return Ok(BedrockAuth::BearerToken(token));
885 }
886 if let Ok(creds) = AwsCredentials::from_env() {
888 return Ok(BedrockAuth::SigV4(creds));
889 }
890 if let Ok(creds) = AwsCredentials::from_credential_process() {
891 self.cache_credentials(&creds);
892 return Ok(BedrockAuth::SigV4(creds));
893 }
894 Ok(BedrockAuth::SigV4(AwsCredentials::from_imds().await?))
895 }
896
/// Returns `true` when the system prompt is longer than 3072 bytes.
fn should_cache_system(text: &str) -> bool {
    const THRESHOLD_BYTES: usize = 3072;
    text.len() > THRESHOLD_BYTES
}
903
904 fn should_cache_conversation(messages: &[ChatMessage]) -> bool {
906 messages.iter().filter(|m| m.role != "system").count() > 4
907 }
908
/// Converts provider-agnostic chat messages into Converse system blocks and messages.
///
/// * `system` – only the first system message is kept; later ones are dropped.
/// * `assistant` – parsed as a tool-call message when possible, otherwise sent as text
///   (blank content becomes `"(empty response)"`).
/// * `tool` – turned into a `toolResult` block inside a `user` message; consecutive tool
///   results are merged into a single `user` message.
/// * any other role – treated as a user message and parsed for inline images.
///
/// Returns `(None, ..)` for the system part when no system message was present.
fn convert_messages(
    messages: &[ChatMessage],
) -> (Option<Vec<SystemBlock>>, Vec<ConverseMessage>) {
    let mut system_blocks = Vec::new();
    let mut converse_messages = Vec::new();

    for msg in messages {
        match msg.role.as_str() {
            "system" => {
                // First system message wins; subsequent ones are ignored.
                if system_blocks.is_empty() {
                    system_blocks.push(SystemBlock::Text(TextBlock {
                        text: msg.content.clone(),
                    }));
                }
            }
            "assistant" => {
                if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
                    converse_messages.push(ConverseMessage {
                        role: "assistant".to_string(),
                        content: blocks,
                    });
                } else {
                    // Substitute a placeholder so the text block is never blank.
                    let text = if msg.content.trim().is_empty() {
                        "(empty response)".to_string()
                    } else {
                        msg.content.clone()
                    };
                    converse_messages.push(ConverseMessage {
                        role: "assistant".to_string(),
                        content: vec![ContentBlock::Text(TextBlock { text })],
                    });
                }
            }
            "tool" => {
                let tool_result_msg = Self::parse_tool_result_message(&msg.content)
                    .unwrap_or_else(|| {
                        // Unparseable result: recover an id from the raw JSON, else from
                        // the last unanswered tool call, else fall back to "unknown".
                        let tool_use_id = Self::extract_tool_call_id(&msg.content)
                            .or_else(|| Self::last_pending_tool_use_id(&converse_messages))
                            .unwrap_or_else(|| "unknown".to_string());

                        ::zeroclaw_log::record!(
                            WARN,
                            ::zeroclaw_log::Event::new(
                                module_path!(),
                                ::zeroclaw_log::Action::Note
                            )
                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
                            &format!(
                                "Failed to parse tool result message, creating error \
                                 toolResult for tool_use_id={}",
                                tool_use_id
                            )
                        );

                        // The raw content is forwarded as an error-status tool result.
                        ConverseMessage {
                            role: "user".to_string(),
                            content: vec![ContentBlock::ToolResult(ToolResultWrapper {
                                tool_result: ToolResultBlock {
                                    tool_use_id,
                                    content: vec![ToolResultContent {
                                        text: msg.content.clone(),
                                    }],
                                    status: "error".to_string(),
                                },
                            })],
                        }
                    });

                // Merge into the previous message when it is a user message made up
                // solely of tool results, so adjacent results share one turn.
                if let Some(last) = converse_messages.last_mut()
                    && last.role == "user"
                    && last
                        .content
                        .iter()
                        .all(|b| matches!(b, ContentBlock::ToolResult(_)))
                {
                    last.content.extend(tool_result_msg.content);
                    continue;
                }
                converse_messages.push(tool_result_msg);
            }
            _ => {
                let content_blocks = Self::parse_user_content_blocks(&msg.content);
                converse_messages.push(ConverseMessage {
                    role: "user".to_string(),
                    content: content_blocks,
                });
            }
        }
    }

    let system = if system_blocks.is_empty() {
        None
    } else {
        Some(system_blocks)
    };
    (system, converse_messages)
}
1018
1019 fn sanitize_empty_content_blocks(messages: &mut [ConverseMessage]) {
1027 for msg in messages.iter_mut() {
1028 msg.content.retain(|block| match block {
1029 ContentBlock::Text(tb) => !tb.text.trim().is_empty(),
1030 _ => true,
1031 });
1032 if msg.content.is_empty() {
1033 msg.content.push(ContentBlock::Text(TextBlock {
1034 text: "(empty)".to_string(),
1035 }));
1036 }
1037 }
1038 }
1039
1040 fn extract_tool_call_id(content: &str) -> Option<String> {
1042 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
1043 value
1044 .get("tool_call_id")
1045 .or_else(|| value.get("tool_use_id"))
1046 .or_else(|| value.get("toolUseId"))
1047 .and_then(serde_json::Value::as_str)
1048 .map(String::from)
1049 }
1050
1051 fn last_pending_tool_use_id(converse_messages: &[ConverseMessage]) -> Option<String> {
1057 let last_assistant = converse_messages
1058 .iter()
1059 .rev()
1060 .find(|m| m.role == "assistant")?;
1061
1062 let tool_use_ids: Vec<&str> = last_assistant
1063 .content
1064 .iter()
1065 .filter_map(|b| match b {
1066 ContentBlock::ToolUse(wrapper) => Some(wrapper.tool_use.tool_use_id.as_str()),
1067 _ => None,
1068 })
1069 .collect();
1070
1071 let answered_ids: Vec<&str> = converse_messages
1072 .iter()
1073 .rev()
1074 .take_while(|m| m.role == "user")
1075 .flat_map(|m| m.content.iter())
1076 .filter_map(|b| match b {
1077 ContentBlock::ToolResult(wrapper) => Some(wrapper.tool_result.tool_use_id.as_str()),
1078 _ => None,
1079 })
1080 .collect();
1081
1082 tool_use_ids
1083 .into_iter()
1084 .find(|id| !answered_ids.contains(id))
1085 .map(String::from)
1086 }
1087
1088 fn parse_user_content_blocks(content: &str) -> Vec<ContentBlock> {
1090 let mut blocks: Vec<ContentBlock> = Vec::new();
1091 let mut remaining = content;
1092 let has_image = content.contains("[IMAGE:");
1093 ::zeroclaw_log::record!(
1094 INFO,
1095 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
1096 &format!(
1097 "parse_user_content_blocks called, len={}, has_image={}",
1098 content.len(),
1099 has_image
1100 )
1101 );
1102
1103 while let Some(start) = remaining.find("[IMAGE:") {
1104 let text_before = &remaining[..start];
1106 if !text_before.trim().is_empty() {
1107 blocks.push(ContentBlock::Text(TextBlock {
1108 text: text_before.to_string(),
1109 }));
1110 }
1111
1112 let after = &remaining[start + 7..]; if let Some(end) = after.find(']') {
1114 let src = &after[..end];
1115 remaining = &after[end + 1..];
1116
1117 if let Some(rest) = src.strip_prefix("data:")
1119 && let Some(semi) = rest.find(';')
1120 {
1121 let mime = &rest[..semi];
1122 let after_semi = &rest[semi + 1..];
1123 if let Some(b64) = after_semi.strip_prefix("base64,") {
1124 let format = match mime {
1125 "image/png" => "png",
1126 "image/gif" => "gif",
1127 "image/webp" => "webp",
1128 _ => "jpeg",
1129 };
1130 blocks.push(ContentBlock::Image(ImageWrapper {
1131 image: ImageBlock {
1132 format: format.to_string(),
1133 source: ImageSource {
1134 bytes: b64.to_string(),
1135 },
1136 },
1137 }));
1138 continue;
1139 }
1140 }
1141 blocks.push(ContentBlock::Text(TextBlock {
1143 text: format!("[image: {}]", src),
1144 }));
1145 } else {
1146 blocks.push(ContentBlock::Text(TextBlock {
1148 text: remaining.to_string(),
1149 }));
1150 break;
1151 }
1152 }
1153
1154 if !remaining.trim().is_empty() {
1156 blocks.push(ContentBlock::Text(TextBlock {
1157 text: remaining.to_string(),
1158 }));
1159 }
1160
1161 if blocks.is_empty() {
1162 let fallback = if content.trim().is_empty() {
1163 "(empty)".to_string()
1164 } else {
1165 content.to_string()
1166 };
1167 blocks.push(ContentBlock::Text(TextBlock { text: fallback }));
1168 }
1169
1170 blocks
1171 }
1172
/// Parses an assistant message stored as JSON with a `tool_calls` array into content blocks.
///
/// Returns `None` when the content is not JSON or has no `tool_calls` array that
/// deserializes into tool calls. Otherwise the blocks are emitted in this order:
/// reasoning blocks, an optional text block, then one `toolUse` block per call.
fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<ContentBlock>> {
    let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
    let tool_calls = value
        .get("tool_calls")
        .and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;

    let mut blocks = Vec::new();

    if let Some(reasoning) = value
        .get("reasoning_content")
        .and_then(serde_json::Value::as_str)
        .filter(|r| !r.is_empty())
    {
        // `reasoning_content` holds one JSON object per line (the format produced by
        // `parse_converse_response`); lines that are not valid JSON are skipped.
        for part in reasoning.split('\n') {
            if let Ok(block) = serde_json::from_str::<serde_json::Value>(part) {
                let text = block
                    .get("text")
                    .and_then(|t| t.as_str())
                    .unwrap_or("")
                    .to_string();
                // An empty signature is treated the same as a missing one.
                let signature = block
                    .get("signature")
                    .and_then(|s| s.as_str())
                    .filter(|s| !s.is_empty())
                    .map(|s| s.to_string());
                blocks.push(ContentBlock::ReasoningContent(ReasoningContentOutWrapper {
                    reasoning_content: ReasoningContentOutBlock {
                        reasoning_text: ReasoningTextOutField { text, signature },
                    },
                }));
            }
        }
    }

    // Accompanying text is included only when it is non-blank, and is trimmed.
    if let Some(text) = value
        .get("content")
        .and_then(serde_json::Value::as_str)
        .map(str::trim)
        .filter(|t| !t.is_empty())
    {
        blocks.push(ContentBlock::Text(TextBlock {
            text: text.to_string(),
        }));
    }
    for call in tool_calls {
        // Arguments that are not valid JSON fall back to an empty object.
        let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
        blocks.push(ContentBlock::ToolUse(ToolUseWrapper {
            tool_use: ToolUseBlock {
                tool_use_id: call.id,
                name: call.name,
                input,
            },
        }));
    }
    Some(blocks)
}
1236
1237 fn parse_tool_result_message(content: &str) -> Option<ConverseMessage> {
1239 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
1240 let tool_use_id = value
1241 .get("tool_call_id")
1242 .or_else(|| value.get("tool_use_id"))
1243 .or_else(|| value.get("toolUseId"))
1244 .and_then(serde_json::Value::as_str)?
1245 .to_string();
1246 let result = value
1247 .get("content")
1248 .and_then(serde_json::Value::as_str)
1249 .unwrap_or("")
1250 .to_string();
1251 Some(ConverseMessage {
1252 role: "user".to_string(),
1253 content: vec![ContentBlock::ToolResult(ToolResultWrapper {
1254 tool_result: ToolResultBlock {
1255 tool_use_id,
1256 content: vec![ToolResultContent { text: result }],
1257 status: "success".to_string(),
1258 },
1259 })],
1260 })
1261 }
1262
1263 fn convert_tools_to_converse(tools: Option<&[ToolSpec]>) -> Option<ToolConfig> {
1266 let items = tools?;
1267 if items.is_empty() {
1268 return None;
1269 }
1270 let tool_defs: Vec<ToolDefinition> = items
1271 .iter()
1272 .map(|tool| ToolDefinition {
1273 tool_spec: ToolSpecDef {
1274 name: tool.name.clone(),
1275 description: tool.description.clone(),
1276 input_schema: InputSchema {
1277 json: tool.parameters.clone(),
1278 },
1279 },
1280 })
1281 .collect();
1282 Some(ToolConfig { tools: tool_defs })
1283 }
1284
1285 fn parse_converse_response(response: ConverseResponse) -> ProviderChatResponse {
1288 let mut text_parts = Vec::new();
1289 let mut thinking_parts = Vec::new();
1290 let mut tool_calls = Vec::new();
1291
1292 let usage = response.usage.map(|u| TokenUsage {
1293 input_tokens: u.input_tokens,
1294 output_tokens: u.output_tokens,
1295 cached_input_tokens: None,
1296 });
1297
1298 if let Some(output) = response.output
1299 && let Some(message) = output.message
1300 {
1301 for block in message.content {
1302 match block {
1303 ResponseContentBlock::Text(tb) => {
1304 let trimmed = tb.text.trim().to_string();
1305 if !trimmed.is_empty() {
1306 text_parts.push(trimmed);
1307 }
1308 }
1309 ResponseContentBlock::ReasoningContent(wrapper) => {
1310 if let Some(reasoning_text) = wrapper.reasoning_content.reasoning_text {
1311 let block = serde_json::json!({
1313 "text": reasoning_text.text.as_deref().unwrap_or(""),
1314 "signature": reasoning_text.signature.as_deref().unwrap_or(""),
1315 });
1316 thinking_parts.push(block.to_string());
1317 }
1318 }
1319 ResponseContentBlock::ToolUse(wrapper) => {
1320 if !wrapper.tool_use.name.is_empty() {
1321 tool_calls.push(ProviderToolCall {
1322 id: wrapper.tool_use.tool_use_id,
1323 name: wrapper.tool_use.name,
1324 arguments: wrapper.tool_use.input.to_string(),
1325 extra_content: None,
1326 });
1327 }
1328 }
1329 ResponseContentBlock::Other(_) => {}
1330 }
1331 }
1332 }
1333
1334 let reasoning_content = if thinking_parts.is_empty() {
1335 None
1336 } else {
1337 Some(thinking_parts.join("\n"))
1338 };
1339
1340 ProviderChatResponse {
1341 text: if text_parts.is_empty() {
1342 None
1343 } else {
1344 Some(text_parts.join("\n"))
1345 },
1346 tool_calls,
1347 usage,
1348 reasoning_content,
1349 }
1350 }
1351
1352 async fn send_converse_request(
1355 &self,
1356 auth: &BedrockAuth,
1357 model: &str,
1358 request_body: &ConverseRequest,
1359 ) -> anyhow::Result<ConverseResponse> {
1360 let payload = serde_json::to_vec(request_body)?;
1361
1362 if let Ok(debug_val) = serde_json::from_slice::<serde_json::Value>(&payload)
1364 && let Some(msgs) = debug_val.get("messages").and_then(|m| m.as_array())
1365 {
1366 for msg in msgs {
1367 if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
1368 for block in content {
1369 if block.get("image").is_some() {
1370 let mut b = block.clone();
1371 if let Some(img) = b.get_mut("image")
1372 && let Some(src) = img.get_mut("source")
1373 && let Some(bytes) = src.get_mut("bytes")
1374 && let Some(s) = bytes.as_str()
1375 {
1376 *bytes = serde_json::json!(format!("<base64 {} chars>", s.len()));
1377 }
1378 ::zeroclaw_log::record!(
1379 INFO,
1380 ::zeroclaw_log::Event::new(
1381 module_path!(),
1382 ::zeroclaw_log::Action::Note
1383 ),
1384 &format!(
1385 "Bedrock image block: {}",
1386 serde_json::to_string(&b).unwrap_or_default()
1387 )
1388 );
1389 }
1390 }
1391 }
1392 }
1393 }
1394
1395 let response: reqwest::Response = match auth {
1396 BedrockAuth::BearerToken(token) => {
1397 let region = Self::resolve_region();
1398 let url = Self::endpoint_url(®ion, model);
1399
1400 self.http_client()
1401 .post(&url)
1402 .header("content-type", "application/json")
1403 .header("Authorization", format!("Bearer {token}"))
1404 .body(payload)
1405 .send()
1406 .await?
1407 }
1408 BedrockAuth::SigV4(credentials) => {
1409 let url = Self::endpoint_url(&credentials.region, model);
1410 let canonical_uri = Self::canonical_uri(model);
1411 let now = chrono::Utc::now();
1412 let host = credentials.host();
1413 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
1414
1415 let mut headers_to_sign = vec![
1416 ("content-type".to_string(), "application/json".to_string()),
1417 ("host".to_string(), host),
1418 ("x-amz-date".to_string(), amz_date.clone()),
1419 ];
1420 if let Some(ref session_token) = credentials.session_token {
1421 headers_to_sign
1422 .push(("x-amz-security-token".to_string(), session_token.clone()));
1423 }
1424 headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
1425
1426 let authorization = build_authorization_header(
1427 credentials,
1428 "POST",
1429 &canonical_uri,
1430 "",
1431 &headers_to_sign,
1432 &payload,
1433 &now,
1434 );
1435
1436 let mut request = self
1437 .http_client()
1438 .post(&url)
1439 .header("content-type", "application/json")
1440 .header("x-amz-date", &amz_date)
1441 .header("authorization", &authorization);
1442
1443 if let Some(ref session_token) = credentials.session_token {
1444 request = request.header("x-amz-security-token", session_token);
1445 }
1446
1447 request.body(payload).send().await?
1448 }
1449 };
1450
1451 if !response.status().is_success() {
1452 return Err(super::api_error("Bedrock", response).await);
1453 }
1454
1455 let converse_response: ConverseResponse = response.json().await?;
1456 Ok(converse_response)
1457 }
1458}
1459
1460#[async_trait]
1463impl ModelProvider for BedrockModelProvider {
1464 fn capabilities(&self) -> ProviderCapabilities {
1465 ProviderCapabilities {
1466 native_tool_calling: true,
1467 vision: true,
1468 prompt_caching: false,
1469 extended_thinking: true,
1470 }
1471 }
1472
1473 fn supports_native_tools(&self) -> bool {
1474 true
1475 }
1476
1477 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
1478 let tool_values: Vec<serde_json::Value> = tools
1479 .iter()
1480 .map(|t| {
1481 serde_json::json!({
1482 "toolSpec": {
1483 "name": t.name,
1484 "description": t.description,
1485 "inputSchema": { "json": t.parameters }
1486 }
1487 })
1488 })
1489 .collect();
1490 ToolsPayload::Anthropic { tools: tool_values }
1491 }
1492
1493 async fn chat_with_system(
1494 &self,
1495 system_prompt: Option<&str>,
1496 message: &str,
1497 model: &str,
1498 temperature: Option<f64>,
1499 ) -> anyhow::Result<String> {
1500 let auth = self.resolve_auth().await?;
1501
1502 let supports_caching = bedrock_model_supports_prompt_caching(model);
1503 let system = system_prompt.map(|text| {
1504 let mut blocks = vec![SystemBlock::Text(TextBlock {
1505 text: text.to_string(),
1506 })];
1507 if supports_caching && Self::should_cache_system(text) {
1508 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1509 cache_point: CachePoint::default_cache(),
1510 }));
1511 }
1512 blocks
1513 });
1514
1515 let mut messages = vec![ConverseMessage {
1516 role: "user".to_string(),
1517 content: Self::parse_user_content_blocks(message),
1518 }];
1519 Self::sanitize_empty_content_blocks(&mut messages);
1520
1521 let request = ConverseRequest {
1522 system,
1523 messages,
1524 inference_config: Some(InferenceConfig {
1525 max_tokens: self.max_tokens,
1526 temperature,
1527 }),
1528 tool_config: None,
1529 additional_model_request_fields: None,
1530 };
1531
1532 let response = self.send_converse_request(&auth, model, &request).await?;
1533
1534 Self::parse_converse_response(response).text.ok_or_else(|| {
1535 ::zeroclaw_log::record!(
1536 ERROR,
1537 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1538 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
1539 "bedrock: empty text in response"
1540 );
1541 anyhow::Error::msg("No response from Bedrock")
1542 })
1543 }
1544
1545 async fn chat(
1546 &self,
1547 request: ProviderChatRequest<'_>,
1548 model: &str,
1549 temperature: Option<f64>,
1550 ) -> anyhow::Result<ProviderChatResponse> {
1551 let auth = self.resolve_auth().await?;
1552
1553 let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
1554
1555 Self::sanitize_empty_content_blocks(&mut converse_messages);
1557
1558 let supports_caching = bedrock_model_supports_prompt_caching(model);
1562
1563 let system = system_blocks.map(|mut blocks| {
1565 let has_large_system = blocks
1566 .iter()
1567 .any(|b| matches!(b, SystemBlock::Text(tb) if Self::should_cache_system(&tb.text)));
1568 if supports_caching && has_large_system {
1569 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1570 cache_point: CachePoint::default_cache(),
1571 }));
1572 }
1573 blocks
1574 });
1575
1576 if supports_caching
1578 && Self::should_cache_conversation(request.messages)
1579 && let Some(last_msg) = converse_messages.last_mut()
1580 {
1581 last_msg
1582 .content
1583 .push(ContentBlock::CachePointBlock(CachePointWrapper {
1584 cache_point: CachePoint::default_cache(),
1585 }));
1586 }
1587
1588 let tool_config = Self::convert_tools_to_converse(request.tools);
1589
1590 let (effective_temperature, additional_fields, effective_max_tokens) = match request
1594 .thinking
1595 {
1596 Some(params) if bedrock_model_supports_native_thinking(model) => {
1597 ::zeroclaw_log::record!(
1598 INFO,
1599 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1600 .with_attrs(::serde_json::json!({"budget_tokens": params.budget_tokens})),
1601 "Bedrock native extended thinking enabled; forcing temperature=1.0"
1602 );
1603 let fields = serde_json::json!({
1604 "thinking": {
1605 "type": "enabled",
1606 "budget_tokens": params.budget_tokens
1607 }
1608 });
1609 let min_required = params.budget_tokens + 1;
1610 let max_tokens = self.max_tokens.max(min_required);
1611 (Some(1.0), Some(fields), max_tokens)
1612 }
1613 Some(_) => {
1614 ::zeroclaw_log::record!(
1615 WARN,
1616 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1617 .with_attrs(::serde_json::json!({"model": model})),
1618 "Native extended thinking requested but model only supports adaptive thinking; falling back to prompt-based reasoning"
1619 );
1620 (temperature, None, self.max_tokens)
1621 }
1622 None => (temperature, None, self.max_tokens),
1623 };
1624
1625 let converse_request = ConverseRequest {
1626 system,
1627 messages: converse_messages,
1628 inference_config: Some(InferenceConfig {
1629 max_tokens: effective_max_tokens,
1630 temperature: effective_temperature,
1631 }),
1632 tool_config,
1633 additional_model_request_fields: additional_fields,
1634 };
1635
1636 let response = self
1637 .send_converse_request(&auth, model, &converse_request)
1638 .await?;
1639
1640 Ok(Self::parse_converse_response(response))
1641 }
1642
1643 async fn warmup(&self) -> anyhow::Result<()> {
1644 let region = match self.auth {
1645 Some(BedrockAuth::SigV4(ref creds)) => creds.region.clone(),
1646 Some(BedrockAuth::BearerToken(_)) => Self::resolve_region(),
1647 None => return Ok(()),
1648 };
1649 let url = format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/");
1650 let _ = self.http_client().get(&url).send().await;
1651 Ok(())
1652 }
1653}
1654
1655impl ::zeroclaw_api::attribution::Attributable for BedrockModelProvider {
1658 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1659 ::zeroclaw_api::attribution::Role::Provider(
1660 ::zeroclaw_api::attribution::ProviderKind::Model(
1661 ::zeroclaw_api::attribution::ModelProviderKind::Bedrock,
1662 ),
1663 )
1664 }
1665 fn alias(&self) -> &str {
1666 &self.alias
1667 }
1668}
1669
1670#[cfg(test)]
1671mod tests {
1672 use super::*;
1673 use crate::test_util::{EnvGuard, env_lock};
1674 use crate::traits::ChatMessage;
1675
/// `sha256_hex` of the empty input yields the expected lowercase hex digest.
#[test]
fn sha256_hex_empty_string() {
    let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    let digest = sha256_hex(b"");
    assert_eq!(digest, expected);
}
1686
/// `sha256_hex` of `"hello"` yields the expected lowercase hex digest.
#[test]
fn sha256_hex_known_input() {
    let expected = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824";
    let digest = sha256_hex(b"hello");
    assert_eq!(digest, expected);
}
1695
// Secret access key shared by the signing-key derivation tests in this module.
// NOTE(review): looks like the example key from AWS's published SigV4 test vectors,
// not a real credential — confirm.
const TEST_VECTOR_SECRET: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
1698
/// `hmac_sha256` with key `"key"` over `"message"` matches the expected MAC.
#[test]
fn hmac_sha256_known_input() {
    let expected = "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a";
    let key_bytes: &[u8] = b"key";
    let mac = hmac_sha256(key_bytes, b"message");
    assert_eq!(hex::encode(&mac), expected);
}
1708
/// The derived signing key is 32 bytes long.
#[test]
fn derive_signing_key_structure() {
    let (date, region, service) = ("20150830", "us-east-1", "iam");
    let signing_key = derive_signing_key(TEST_VECTOR_SECRET, date, region, service);
    assert_eq!(signing_key.len(), 32);
}
1715
/// The derived signing key for the fixed date/region/service matches the expected hex.
#[test]
fn derive_signing_key_known_test_vector() {
    let expected = "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9";
    let (date, region, service) = ("20150830", "us-east-1", "iam");
    let signing_key = derive_signing_key(TEST_VECTOR_SECRET, date, region, service);
    assert_eq!(hex::encode(&signing_key), expected);
}
1725
/// Checks the overall shape of the `Authorization` header built for credentials
/// without a session token: algorithm prefix, credential scope, signed-header list
/// and the presence of a signature.
#[test]
fn build_authorization_header_format() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
        session_token: None,
        region: "us-east-1".to_string(),
        expires_at: None,
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
    ];

    // The last argument is a reference to the signing timestamp.
    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/anthropic.claude-3-sonnet/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/"));
    assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
    assert!(auth.contains("Signature="));
    assert!(auth.contains("/us-east-1/bedrock/aws4_request"));
}
1765
/// When `x-amz-security-token` is among the headers to sign, its name appears in the
/// resulting `Authorization` header.
#[test]
fn build_authorization_header_includes_security_token_in_signed_headers() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
        session_token: Some("session-token-value".to_string()),
        region: "us-east-1".to_string(),
        expires_at: None,
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
        (
            "x-amz-security-token".to_string(),
            "session-token-value".to_string(),
        ),
    ];

    // The last argument is a reference to the signing timestamp.
    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/test-model/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    assert!(auth.contains("x-amz-security-token"));
}
1805
/// `AwsCredentials::host` embeds the credentials' region in the runtime host name.
#[test]
fn credentials_host_formats_correctly() {
    let region = "us-west-2".to_string();
    let creds = AwsCredentials {
        access_key_id: "AKID".to_string(),
        secret_access_key: "secret".to_string(),
        session_token: None,
        region,
        expires_at: None,
    };
    let host = creds.host();
    assert_eq!(host, "bedrock-runtime.us-west-2.amazonaws.com");
}
1819
/// Constructing a provider must not panic, whatever credentials are available.
#[test]
fn creates_without_credentials() {
    drop(BedrockModelProvider::new("test"));
}
1827
1828 #[tokio::test]
1829 #[allow(clippy::await_holding_lock)]
1830 async fn chat_fails_without_credentials() {
1831 let _env_lock = env_lock();
1832 let _ak = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
1833 let _sk = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);
1834 let _bearer = EnvGuard::set("BEDROCK_API_KEY", None);
1835 let _config = EnvGuard::set("AWS_CONFIG_FILE", Some("/dev/null"));
1836 let model_provider = BedrockModelProvider {
1837 alias: "test".to_string(),
1838 auth: None,
1839 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
1840 cred_cache: Mutex::new(None),
1841 };
1842 let result = model_provider
1843 .chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", Some(0.7))
1844 .await;
1845 assert!(result.is_err());
1846 let err = result.unwrap_err().to_string();
1847 assert!(
1848 err.contains("credentials not set")
1849 || err.contains("169.254.169.254")
1850 || err.to_lowercase().contains("credential")
1851 || err.to_lowercase().contains("builder error"),
1852 "Expected missing-credentials style error, got: {err}"
1853 );
1854 }
1855
/// `with_bearer_token` stores the supplied token as bearer auth.
#[test]
fn creates_with_bearer_token() {
    let model_provider = BedrockModelProvider::with_bearer_token("test", "test-api-key");
    assert!(model_provider.auth.is_some());

    let holds_expected_token = matches!(
        &model_provider.auth,
        Some(BedrockAuth::BearerToken(t)) if t == "test-api-key"
    );
    assert!(holds_expected_token);
}
1866
/// With only `BEDROCK_API_KEY` set, `new` picks up bearer auth from the environment.
#[test]
fn bearer_token_from_env() {
    let _env_lock = env_lock();
    let _guard = EnvGuard::set("BEDROCK_API_KEY", Some("env-bearer-token"));
    let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
    let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);

    let model_provider = BedrockModelProvider::new("test");
    let uses_env_token = matches!(
        &model_provider.auth,
        Some(BedrockAuth::BearerToken(t)) if t == "env-bearer-token"
    );
    assert!(uses_env_token);
}
1881
/// When both a bearer key and access keys are set, bearer auth wins.
#[test]
fn bearer_token_precedence() {
    let _env_lock = env_lock();
    let _bearer_guard = EnvGuard::set("BEDROCK_API_KEY", Some("bearer-key"));
    let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("AKIAEXAMPLE"));
    let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret"));

    let model_provider = BedrockModelProvider::new("test");
    let prefers_bearer = matches!(
        &model_provider.auth,
        Some(BedrockAuth::BearerToken(t)) if t == "bearer-key"
    );
    assert!(prefers_bearer);
}
1896
/// `endpoint_url` combines region and model id into the Converse endpoint URL.
#[test]
fn endpoint_url_formats_correctly() {
    let expected =
        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse";
    let url = BedrockModelProvider::endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6");
    assert_eq!(url, expected);
}
1907
/// `endpoint_url` leaves a `:` in the model id unencoded.
#[test]
fn endpoint_url_keeps_raw_colon() {
    let model = "anthropic.claude-3-5-haiku-20241022-v1:0";
    let url = BedrockModelProvider::endpoint_url("us-west-2", model);
    let expected_path = "/model/anthropic.claude-3-5-haiku-20241022-v1:0/converse";
    assert!(url.contains(expected_path));
}
1917
/// `canonical_uri` percent-encodes a `:` in the model id as `%3A`.
#[test]
fn canonical_uri_encodes_colon() {
    let model = "anthropic.claude-3-5-haiku-20241022-v1:0";
    let uri = BedrockModelProvider::canonical_uri(model);
    let expected = "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse";
    assert_eq!(uri, expected);
}
1927
/// A model id without `:` appears verbatim in the canonical URI.
#[test]
fn canonical_uri_no_colon_unchanged() {
    let model = "anthropic.claude-sonnet-4-6";
    let expected = "/model/anthropic.claude-sonnet-4-6/converse";
    assert_eq!(BedrockModelProvider::canonical_uri(model), expected);
}
1933
/// A system message is split out into the system blocks and not kept as a message.
#[test]
fn convert_messages_system_extracted() {
    let messages = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("Hello"),
    ];

    let (system, msgs) = BedrockModelProvider::convert_messages(&messages);

    let system_blocks = system.expect("system blocks should be present");
    assert_eq!(system_blocks.len(), 1);
    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].role, "user");
}
1949
/// User and assistant messages keep their roles and order; no system blocks appear.
#[test]
fn convert_messages_user_and_assistant() {
    let messages = vec![
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi there"),
    ];

    let (system, msgs) = BedrockModelProvider::convert_messages(&messages);

    assert!(system.is_none());
    let roles: Vec<&str> = msgs.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles.len(), 2);
    assert_eq!(roles[0], "user");
    assert_eq!(roles[1], "assistant");
}
1962
/// A tool message becomes a `user` message whose first block is a tool result.
#[test]
fn convert_messages_tool_role_to_tool_result() {
    let messages = vec![ChatMessage::tool(
        r#"{"tool_call_id": "call_123", "content": "Result data"}"#,
    )];

    let (_, msgs) = BedrockModelProvider::convert_messages(&messages);

    assert_eq!(msgs.len(), 1);
    let converted = &msgs[0];
    assert_eq!(converted.role, "user");
    assert!(matches!(converted.content[0], ContentBlock::ToolResult(_)));
}
1972
/// An assistant message carrying text plus a tool call yields a text block followed
/// by a tool-use block.
#[test]
fn convert_messages_assistant_tool_calls_parsed() {
    let messages = vec![ChatMessage::assistant(
        r#"{"content": "Let me check", "tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]}"#,
    )];

    let (_, msgs) = BedrockModelProvider::convert_messages(&messages);

    assert_eq!(msgs.len(), 1);
    let assistant = &msgs[0];
    assert_eq!(assistant.role, "assistant");
    assert_eq!(assistant.content.len(), 2);
    assert!(matches!(assistant.content[0], ContentBlock::Text(_)));
    assert!(matches!(assistant.content[1], ContentBlock::ToolUse(_)));
}
1984
/// Plain assistant text is converted to a single message starting with a text block.
#[test]
fn convert_messages_plain_assistant_text() {
    let messages = vec![ChatMessage::assistant("Just text")];

    let (_, msgs) = BedrockModelProvider::convert_messages(&messages);

    assert_eq!(msgs.len(), 1);
    let first_block = &msgs[0].content[0];
    assert!(matches!(first_block, ContentBlock::Text(_)));
}
1992
/// A short system prompt is not selected for caching.
#[test]
fn should_cache_system_small_prompt() {
    let cacheable = BedrockModelProvider::should_cache_system("Short prompt");
    assert!(!cacheable);
}
1999
/// A 3073-character system prompt is selected for caching.
#[test]
fn should_cache_system_large_prompt() {
    let prompt = "a".repeat(3073);
    let cacheable = BedrockModelProvider::should_cache_system(&prompt);
    assert!(cacheable);
}
2005
/// The caching decision flips between 3072 and 3073 characters.
#[test]
fn should_cache_system_boundary() {
    let at_limit = "a".repeat(3072);
    let over_limit = "a".repeat(3073);
    assert!(!BedrockModelProvider::should_cache_system(&at_limit));
    assert!(BedrockModelProvider::should_cache_system(&over_limit));
}
2013
/// A conversation with one system, one user and one assistant message is not cached.
#[test]
fn should_cache_conversation_short() {
    let messages = vec![
        ChatMessage::system("System"),
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi"),
    ];
    let cacheable = BedrockModelProvider::should_cache_conversation(&messages);
    assert!(!cacheable);
}
2023
/// A system message followed by five alternating user/assistant messages is cached.
#[test]
fn should_cache_conversation_long() {
    let mut messages = vec![ChatMessage::system("System")];
    messages.extend((0..5).map(|i| {
        let role = if i % 2 == 0 { "user" } else { "assistant" };
        ChatMessage {
            role: role.to_string(),
            content: format!("Message {i}"),
        }
    }));
    let cacheable = BedrockModelProvider::should_cache_conversation(&messages);
    assert!(cacheable);
}
2035
/// A single tool spec produces a tool config with one definition carrying its name.
#[test]
fn convert_tools_to_converse_formats_correctly() {
    let shell = ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}}),
    };
    let tools = vec![shell];

    let config = BedrockModelProvider::convert_tools_to_converse(Some(&tools));

    let config = config.expect("tool config should be produced");
    assert_eq!(config.tools.len(), 1);
    assert_eq!(config.tools[0].tool_spec.name, "shell");
}
2051
/// Both an empty slice and an absent slice produce no tool config.
#[test]
fn convert_tools_to_converse_empty_returns_none() {
    let from_empty = BedrockModelProvider::convert_tools_to_converse(Some(&[]));
    assert!(from_empty.is_none());

    let from_absent = BedrockModelProvider::convert_tools_to_converse(None);
    assert!(from_absent.is_none());
}
2057
/// A request with `system: None` serializes without a `system` key and uses the
/// `maxTokens` spelling.
#[test]
fn converse_request_serializes_without_system() {
    let user_message = ConverseMessage {
        role: "user".to_string(),
        content: vec![ContentBlock::Text(TextBlock {
            text: "Hello".to_string(),
        })],
    };
    let inference_config = InferenceConfig {
        max_tokens: 4096,
        temperature: Some(0.7),
    };
    let req = ConverseRequest {
        system: None,
        messages: vec![user_message],
        inference_config: Some(inference_config),
        tool_config: None,
        additional_model_request_fields: None,
    };

    let json = serde_json::to_string(&req).unwrap();

    assert!(!json.contains("system"));
    assert!(json.contains("Hello"));
    assert!(json.contains("maxTokens"));
}
2082
/// Opus 4.7 model ids are reported as not supporting native thinking.
#[test]
fn bedrock_model_supports_native_thinking_excludes_opus_4_7() {
    let with_region_prefix = "us.anthropic.claude-opus-4-7";
    let with_version_suffix = "anthropic.claude-opus-4-7-v1:0";
    assert!(!bedrock_model_supports_native_thinking(with_region_prefix));
    assert!(!bedrock_model_supports_native_thinking(with_version_suffix));
}
2094
/// Opus 4.6, Sonnet 4.6 and Haiku 4.5 ids are reported as supporting native thinking.
#[test]
fn bedrock_model_supports_native_thinking_allows_other_models() {
    let opus = "us.anthropic.claude-opus-4-6-v1";
    let sonnet = "us.anthropic.claude-sonnet-4-6-v1";
    let haiku = "us.anthropic.claude-haiku-4-5-v1";
    assert!(bedrock_model_supports_native_thinking(opus));
    assert!(bedrock_model_supports_native_thinking(sonnet));
    assert!(bedrock_model_supports_native_thinking(haiku));
}
2107
/// Claude and Nova model ids are reported as supporting prompt caching.
#[test]
fn prompt_caching_supported_for_claude_and_nova() {
    let supported_models = [
        "anthropic.claude-3-5-sonnet-20241022-v2:0",
        "us.anthropic.claude-sonnet-4-6-v1",
        "anthropic.claude-3-7-sonnet-20250219-v1:0",
        "amazon.nova-pro-v1:0",
        "us.amazon.nova-lite-v1:0",
    ];
    for model in supported_models {
        let supported = bedrock_model_supports_prompt_caching(model);
        assert!(supported, "expected prompt caching support for {model}");
    }
}
2123
/// Qwen, Llama, Mistral and DeepSeek ids are reported as not supporting prompt caching.
#[test]
fn prompt_caching_unsupported_for_other_families() {
    let unsupported_models = [
        "qwen.qwen3-coder-next",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-large-2407-v1:0",
        "deepseek.r1-v1:0",
    ];
    for model in unsupported_models {
        let supported = bedrock_model_supports_prompt_caching(model);
        assert!(!supported, "expected NO prompt caching support for {model}");
    }
}
2140
/// The prompt-caching check ignores letter case in the model id.
#[test]
fn prompt_caching_match_is_case_insensitive() {
    let upper_claude = bedrock_model_supports_prompt_caching("ANTHROPIC.CLAUDE-X");
    let mixed_nova = bedrock_model_supports_prompt_caching("Amazon.Nova-Pro");
    let upper_qwen = bedrock_model_supports_prompt_caching("QWEN.qwen3");
    assert!(upper_claude);
    assert!(mixed_nova);
    assert!(!upper_qwen);
}
2147
/// With `temperature: None`, the serialized config has `maxTokens` but no temperature.
#[test]
fn inference_config_serializes_without_temperature_when_none() {
    let json = serde_json::to_string(&InferenceConfig {
        max_tokens: 4096,
        temperature: None,
    })
    .unwrap();

    assert!(json.contains("maxTokens"));
    let mentions_temperature = json.contains("temperature");
    assert!(
        !mentions_temperature,
        "expected temperature to be omitted, got: {json}"
    );
}
2161
/// With `temperature: Some(..)`, the serialized config has both `maxTokens` and
/// a temperature.
#[test]
fn inference_config_serializes_with_temperature_when_some() {
    let json = serde_json::to_string(&InferenceConfig {
        max_tokens: 4096,
        temperature: Some(0.7),
    })
    .unwrap();

    assert!(json.contains("maxTokens"));
    let mentions_temperature = json.contains("temperature");
    assert!(
        mentions_temperature,
        "expected temperature to be present, got: {json}"
    );
}
2175
/// A response with one text block parses to that text and no tool calls.
#[test]
fn converse_response_deserializes_text() {
    let json = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [{"text": "Hello from Bedrock"}]
            }
        },
        "stopReason": "end_turn"
    }"#;

    let resp = serde_json::from_str::<ConverseResponse>(json).unwrap();
    let parsed = BedrockModelProvider::parse_converse_response(resp);

    assert_eq!(parsed.text.as_deref(), Some("Hello from Bedrock"));
    assert!(parsed.tool_calls.is_empty());
}
2192
/// A response with one tool-use block parses to one tool call and no text.
#[test]
fn converse_response_deserializes_tool_use() {
    let json = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {"toolUse": {"toolUseId": "call_1", "name": "shell", "input": {"command": "ls"}}}
                ]
            }
        },
        "stopReason": "tool_use"
    }"#;

    let resp = serde_json::from_str::<ConverseResponse>(json).unwrap();
    let parsed = BedrockModelProvider::parse_converse_response(resp);

    assert!(parsed.text.is_none());
    assert_eq!(parsed.tool_calls.len(), 1);
    let call = &parsed.tool_calls[0];
    assert_eq!(call.name, "shell");
    assert_eq!(call.id, "call_1");
}
2213
/// A response whose `output` is `null` parses to no text and no tool calls.
#[test]
fn converse_response_empty_output() {
    let resp = serde_json::from_str::<ConverseResponse>(r#"{"output": null, "stopReason": null}"#)
        .unwrap();
    let parsed = BedrockModelProvider::parse_converse_response(resp);
    assert!(parsed.text.is_none());
    assert!(parsed.tool_calls.is_empty());
}
2222
/// A text content block serializes to `{"text": ...}` with no wrapper object.
#[test]
fn content_block_text_serializes_as_flat_string() {
    let text_block = TextBlock {
        text: "Hello".to_string(),
    };
    let json = serde_json::to_string(&ContentBlock::Text(text_block)).unwrap();
    assert_eq!(json, r#"{"text":"Hello"}"#);
}
2232
/// A tool-use content block serializes under a `toolUse` key with a `toolUseId` field.
#[test]
fn content_block_tool_use_serializes_with_nested_object() {
    let tool_use = ToolUseBlock {
        tool_use_id: "call_1".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "ls"}),
    };
    let block = ContentBlock::ToolUse(ToolUseWrapper { tool_use });

    let json = serde_json::to_string(&block).unwrap();

    assert!(json.contains(r#""toolUse""#));
    assert!(json.contains(r#""toolUseId":"call_1""#));
}
2246
/// A cache-point content block serializes to `{"cachePoint":{"type":"default"}}`.
#[test]
fn content_block_cache_point_serializes() {
    let cache_point = CachePoint::default_cache();
    let block = ContentBlock::CachePointBlock(CachePointWrapper { cache_point });
    let json = serde_json::to_string(&block).unwrap();
    assert_eq!(json, r#"{"cachePoint":{"type":"default"}}"#);
}
2255
/// A text content block survives a serialize/deserialize round trip.
#[test]
fn content_block_text_round_trips() {
    let original = ContentBlock::Text(TextBlock {
        text: "Hello".to_string(),
    });

    let json = serde_json::to_string(&original).unwrap();
    let deserialized: ContentBlock = serde_json::from_str(&json).unwrap();

    let round_tripped = match deserialized {
        ContentBlock::Text(tb) => tb.text == "Hello",
        _ => false,
    };
    assert!(round_tripped);
}
2265
/// The default cache point serializes to `{"type":"default"}`.
#[test]
fn cache_point_serializes() {
    let json = serde_json::to_string(&CachePoint::default_cache()).unwrap();
    assert_eq!(json, r#"{"type":"default"}"#);
}
2272
2273 #[tokio::test]
2274 async fn warmup_without_credentials_is_noop() {
2275 let model_provider = BedrockModelProvider {
2276 alias: "test".to_string(),
2277 auth: None,
2278 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
2279 cred_cache: Mutex::new(None),
2280 };
2281 let result = model_provider.warmup().await;
2282 assert!(result.is_ok());
2283 }
2284
/// The reported capabilities include native tool calling.
#[test]
fn capabilities_reports_native_tool_calling() {
    let model_provider = BedrockModelProvider {
        alias: "test".to_string(),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: Mutex::new(None),
    };
    let native_tools = model_provider.capabilities().native_tool_calling;
    assert!(native_tools);
}
2296
/// The `usage` object's token counts are deserialized into the response.
#[test]
fn converse_response_parses_usage() {
    let json = r#"{
        "output": {"message": {"role": "assistant", "content": [{"text": {"text": "Hello"}}]}},
        "usage": {"inputTokens": 500, "outputTokens": 100}
    }"#;

    let resp = serde_json::from_str::<ConverseResponse>(json).unwrap();

    let usage = resp.usage.unwrap();
    let (input, output) = (usage.input_tokens, usage.output_tokens);
    assert_eq!(input, Some(500));
    assert_eq!(output, Some(100));
}
2308
/// A response without a `usage` object deserializes with `usage` unset.
#[test]
fn converse_response_parses_without_usage() {
    let resp = serde_json::from_str::<ConverseResponse>(
        r#"{"output": {"message": {"role": "assistant", "content": []}}}"#,
    )
    .unwrap();
    assert!(resp.usage.is_none());
}
2315
/// A tool message whose content is not JSON still becomes a `user` message whose
/// first block is a tool result.
#[test]
fn fallback_tool_result_emits_tool_result_block_not_text() {
    let assistant_turn = ChatMessage::assistant(
        r#"{"content":"","tool_calls":[{"id":"tool_1","name":"shell","arguments":"{}"}]}"#,
    );
    let raw_tool_turn = ChatMessage {
        role: "tool".to_string(),
        content: "not valid json".to_string(),
    };
    let messages = vec![ChatMessage::user("do something"), assistant_turn, raw_tool_turn];

    let (_, msgs) = BedrockModelProvider::convert_messages(&messages);

    let tool_msg = &msgs[2];
    assert_eq!(tool_msg.role, "user");
    let first_block = &tool_msg.content[0];
    assert!(
        matches!(first_block, ContentBlock::ToolResult(_)),
        "Expected ToolResult block, got {:?}",
        tool_msg.content[0]
    );
}
2341
/// For a non-JSON tool message, the tool result takes its id from the preceding
/// assistant tool call and has status `"error"`.
#[test]
fn fallback_recovers_tool_use_id_from_assistant() {
    let assistant_turn = ChatMessage::assistant(
        r#"{"content":"","tool_calls":[{"id":"tool_abc","name":"shell","arguments":"{}"}]}"#,
    );
    let raw_tool_turn = ChatMessage {
        role: "tool".to_string(),
        content: "raw output with no json".to_string(),
    };
    let messages = vec![ChatMessage::user("run it"), assistant_turn, raw_tool_turn];

    let (_, msgs) = BedrockModelProvider::convert_messages(&messages);

    let ContentBlock::ToolResult(wrapper) = &msgs[2].content[0] else {
        panic!("Expected ToolResult block");
    };
    assert_eq!(wrapper.tool_result.tool_use_id, "tool_abc");
    assert_eq!(wrapper.tool_result.status, "error");
}
2362
/// Two back-to-back `tool` messages are folded into one `user` message that
/// holds two `ToolResult` blocks, giving three converted messages in total.
#[test]
fn consecutive_tool_results_merged_into_single_message() {
    let assistant_turn = r#"{"content":"","tool_calls":[{"id":"t1","name":"a","arguments":"{}"},{"id":"t2","name":"b","arguments":"{}"}]}"#;
    let history = vec![
        ChatMessage::user("do two things"),
        ChatMessage::assistant(assistant_turn),
        ChatMessage::tool(r#"{"tool_call_id":"t1","content":"result 1"}"#),
        ChatMessage::tool(r#"{"tool_call_id":"t2","content":"result 2"}"#),
    ];

    let (_, converted) = BedrockModelProvider::convert_messages(&history);
    assert_eq!(
        converted.len(),
        3,
        "Expected 3 messages, got {}",
        converted.len()
    );

    let merged = &converted[2];
    assert_eq!(merged.role, "user");
    assert_eq!(
        merged.content.len(),
        2,
        "Expected 2 tool results in one message"
    );
    // Length was pinned to 2 above, so this visits exactly blocks 0 and 1.
    for block in &merged.content {
        assert!(matches!(block, ContentBlock::ToolResult(_)));
    }
}
2385
/// `extract_tool_call_id` finds the id under `tool_call_id`, `tool_use_id`,
/// or `toolUseId`, and yields `None` for input that is not JSON.
#[test]
fn extract_tool_call_id_tries_multiple_field_names() {
    // (raw input, expected extraction), checked in the listed order.
    let cases = [
        (r#"{"tool_call_id":"a"}"#, Some("a".to_string())),
        (r#"{"tool_use_id":"b"}"#, Some("b".to_string())),
        (r#"{"toolUseId":"c"}"#, Some("c".to_string())),
        ("not json at all", None),
    ];

    for (raw, expected) in &cases {
        assert_eq!(BedrockModelProvider::extract_tool_call_id(*raw), *expected);
    }
}
2405
/// `parse_tool_result_message` accepts a payload keyed by `tool_use_id` and
/// propagates that id into the resulting `ToolResult` block.
#[test]
fn parse_tool_result_accepts_alternate_id_fields() {
    let parsed = BedrockModelProvider::parse_tool_result_message(
        r#"{"tool_use_id":"x","content":"ok"}"#,
    );
    assert!(parsed.is_some());

    let message = parsed.unwrap();
    match &message.content[0] {
        ContentBlock::ToolResult(wrapper) => {
            assert_eq!(wrapper.tool_result.tool_use_id, "x");
        }
        _ => panic!("Expected ToolResult"),
    }
}
2418
/// An assistant message whose only block is an empty text block is kept by
/// `sanitize_empty_content_blocks`, with its text replaced by `(empty)`.
#[test]
fn sanitize_removes_empty_text_blocks() {
    let blank_block = ContentBlock::Text(TextBlock {
        text: String::new(),
    });
    let mut batch = vec![ConverseMessage {
        role: String::from("assistant"),
        content: vec![blank_block],
    }];

    BedrockModelProvider::sanitize_empty_content_blocks(&mut batch);

    assert_eq!(batch.len(), 1);
    match &batch[0].content[0] {
        ContentBlock::Text(tb) => assert_eq!(tb.text, "(empty)"),
        _ => panic!("Expected Text block with placeholder"),
    }
}
2435
/// A text block that already has content passes through
/// `sanitize_empty_content_blocks` unchanged.
#[test]
fn sanitize_preserves_non_empty_text_blocks() {
    let greeting = ContentBlock::Text(TextBlock {
        text: String::from("Hello"),
    });
    let mut batch = vec![ConverseMessage {
        role: String::from("user"),
        content: vec![greeting],
    }];

    BedrockModelProvider::sanitize_empty_content_blocks(&mut batch);

    match &batch[0].content[0] {
        ContentBlock::Text(tb) => assert_eq!(tb.text, "Hello"),
        _ => panic!("Expected preserved Text block"),
    }
}
2451
/// An assistant message with empty content, sandwiched between two user
/// messages, is converted into an assistant message whose first block is a
/// non-empty text block.
#[test]
fn convert_messages_empty_assistant_gets_placeholder() {
    let silent_assistant = ChatMessage {
        role: String::from("assistant"),
        content: String::new(),
    };
    let history = vec![
        ChatMessage::user("Hello"),
        silent_assistant,
        ChatMessage::user("Continue"),
    ];

    let (_, converted) = BedrockModelProvider::convert_messages(&history);

    let assistant_turn = &converted[1];
    assert_eq!(assistant_turn.role, "assistant");
    match &assistant_turn.content[0] {
        ContentBlock::Text(tb) => {
            assert!(!tb.text.is_empty(), "Assistant text should not be empty");
        }
        _ => panic!("Expected Text block for assistant message"),
    }
}
2471
/// The `[default]` section yields both its `credential_process` command
/// (verbatim, including embedded `=` signs) and its `region`.
#[test]
fn parse_aws_config_default_profile() {
    let ini = "\
[default]
region=us-west-2
credential_process=ada credentials print --account=123 --provider=conduit --role=MyRole
";
    let parsed = AwsCredentials::parse_aws_config(ini, "default");
    assert!(parsed.is_some());

    let (command, region) = parsed.unwrap();
    assert_eq!(
        command,
        "ada credentials print --account=123 --provider=conduit --role=MyRole"
    );
    assert_eq!(region.as_deref(), Some("us-west-2"));
}
2490
/// Looking up `myprofile` reads the `[profile myprofile]` section rather
/// than `[default]`: the command mentions the profile and the region is the
/// section's own `eu-west-1`.
#[test]
fn parse_aws_config_named_profile() {
    let ini = "\
[default]
region=us-east-1

[profile myprofile]
region=eu-west-1
credential_process=aws sso get-role-credentials --profile myprofile
";
    let parsed = AwsCredentials::parse_aws_config(ini, "myprofile");
    assert!(parsed.is_some());

    let (command, region) = parsed.unwrap();
    assert!(command.contains("myprofile"));
    assert_eq!(region.as_deref(), Some("eu-west-1"));
}
2507
/// A matching section that has a `region` but no `credential_process`
/// produces `None`.
#[test]
fn parse_aws_config_missing_credential_process() {
    let ini = "\
[default]
region=us-west-2
";
    assert!(AwsCredentials::parse_aws_config(ini, "default").is_none());
}
2517
/// Lines starting with `#` or `;` are skipped, so only the uncommented
/// `credential_process` entry is returned.
#[test]
fn parse_aws_config_ignores_comments() {
    let ini = "\
[default]
# credential_process=should-be-ignored
; credential_process=also-ignored
credential_process=real-command
";
    let parsed = AwsCredentials::parse_aws_config(ini, "default");
    assert!(parsed.is_some());

    let (command, _region) = parsed.unwrap();
    assert_eq!(command, "real-command");
}
2530
/// Asking for a profile that has no section in the config produces `None`,
/// even though `[default]` defines a `credential_process`.
#[test]
fn parse_aws_config_nonexistent_profile() {
    let ini = "\
[default]
credential_process=some-command
";
    assert!(AwsCredentials::parse_aws_config(ini, "nonexistent").is_none());
}
2540
/// Parses a `[default]` profile whose `credential_process` is an `echo` of a
/// credentials JSON document, runs that command through `sh -c`, and checks
/// the `AccessKeyId`, `SecretAccessKey`, and `SessionToken` fields of the
/// emitted JSON.
///
/// Needs an `sh` executable to be available, since the command is spawned
/// via `std::process::Command::new("sh")`.
#[test]
fn from_credential_process_parses_json_output() {
    let config = "\
[default]
credential_process=echo '{\"Version\":1,\"AccessKeyId\":\"AKIA\",\"SecretAccessKey\":\"secret\",\"SessionToken\":\"tok\"}'
region=ap-southeast-1
";
    let (cmd, region) = AwsCredentials::parse_aws_config(config, "default").unwrap();
    assert!(cmd.starts_with("echo"));
    assert_eq!(region.as_deref(), Some("ap-southeast-1"));

    let output = std::process::Command::new("sh")
        .args(["-c", &cmd])
        .output()
        .expect("failed to spawn `sh` to run the credential_process command");

    // Check the exit status before touching stdout: otherwise a failing
    // command shows up only as an opaque JSON parse error on empty output.
    assert!(
        output.status.success(),
        "credential_process command exited with {}: {}",
        output.status,
        String::from_utf8_lossy(&output.stderr)
    );

    let json: serde_json::Value = serde_json::from_slice(&output.stdout).unwrap();
    assert_eq!(json["AccessKeyId"].as_str(), Some("AKIA"));
    assert_eq!(json["SecretAccessKey"].as_str(), Some("secret"));
    assert_eq!(json["SessionToken"].as_str(), Some("tok"));
}
2562
/// With `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` set,
/// `AwsCredentials::from_env` succeeds and reports the access key taken from
/// the environment.
///
/// Only `from_env` is exercised here; no `credential_process` source is
/// configured by this test.
#[test]
fn env_vars_take_precedence_over_credential_process() {
    // Guards are bound to named locals so they live until the end of the
    // test; presumably they restore the environment on drop (TODO confirm).
    let _env_lock = env_lock();
    let _access_key_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("FROM_ENV"));
    let _secret_key_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret_from_env"));

    let loaded = AwsCredentials::from_env();
    assert!(loaded.is_ok());

    let credentials = loaded.unwrap();
    assert_eq!(credentials.access_key_id, "FROM_ENV");
}
2573
2574 fn make_creds(expires_at: Option<chrono::DateTime<chrono::Utc>>) -> AwsCredentials {
2577 AwsCredentials {
2578 access_key_id: "AKIA".to_string(),
2579 secret_access_key: "secret".to_string(),
2580 session_token: Some("tok".to_string()),
2581 region: "us-west-2".to_string(),
2582 expires_at,
2583 }
2584 }
2585
/// Credentials with no `expires_at` are not reported as expired.
#[test]
fn is_expired_returns_false_when_no_expiry() {
    let never_expiring = make_creds(None);
    let expired = never_expiring.is_expired();
    assert!(!expired);
}
2591
/// Credentials expiring one hour from now are not reported as expired.
#[test]
fn is_expired_returns_false_when_future() {
    let in_one_hour = chrono::Utc::now() + chrono::Duration::hours(1);
    assert!(!make_creds(Some(in_one_hour)).is_expired());
}
2598
2599 #[test]
2600 fn is_expired_returns_true_when_past() {
2601 let past = chrono::Utc::now() - chrono::Duration::hours(1);
2602 let creds = make_creds(Some(past));
2603 assert!(creds.is_expired());
2604 }
2605
/// Credentials that expire only 30 seconds from now are already treated as
/// expired.
#[test]
fn is_expired_returns_true_within_skew_window() {
    let in_thirty_seconds = chrono::Utc::now() + chrono::Duration::seconds(30);
    assert!(make_creds(Some(in_thirty_seconds)).is_expired());
}
2613
/// A provider whose credential cache holds nothing returns `None` from
/// `cached_credentials`.
#[test]
fn cached_credentials_returns_none_when_empty() {
    let empty_cache = Mutex::new(None);
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: empty_cache,
    };

    let cached = provider.cached_credentials();
    assert!(cached.is_none());
}
2624
/// A provider whose cache holds credentials expiring in one hour hands them
/// back from `cached_credentials`.
#[test]
fn cached_credentials_returns_some_when_valid() {
    let in_one_hour = chrono::Utc::now() + chrono::Duration::hours(1);
    let warm_cache = Mutex::new(Some(make_creds(Some(in_one_hour))));
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: warm_cache,
    };

    let lookup = provider.cached_credentials();
    assert!(lookup.is_some());

    let credentials = lookup.unwrap();
    assert_eq!(credentials.access_key_id, "AKIA");
}
2638
/// A provider whose cache holds credentials that expired an hour ago returns
/// `None` from `cached_credentials`.
#[test]
fn cached_credentials_returns_none_when_expired() {
    let one_hour_ago = chrono::Utc::now() - chrono::Duration::hours(1);
    let stale_cache = Mutex::new(Some(make_creds(Some(one_hour_ago))));
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: stale_cache,
    };

    let lookup = provider.cached_credentials();
    assert!(lookup.is_none());
}
2650
/// Starting from an empty cache, `cache_credentials` stores credentials that
/// a later `cached_credentials` call can see.
#[test]
fn cache_credentials_stores_and_retrieves() {
    let in_one_hour = chrono::Utc::now() + chrono::Duration::hours(1);
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: Mutex::new(None),
    };

    // Before: nothing cached.
    assert!(provider.cached_credentials().is_none());

    let fresh = make_creds(Some(in_one_hour));
    provider.cache_credentials(&fresh);

    // After: the stored credentials are visible.
    assert!(provider.cached_credentials().is_some());
}
2664}