1use crate::traits::{
11 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
12 ModelProvider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
13};
14use async_trait::async_trait;
15use hmac::{Hmac, Mac};
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use std::sync::Mutex;
20use zeroclaw_api::tool::ToolSpec;
21
/// Host-name prefix of the Bedrock runtime endpoint; combined with the region as
/// `{ENDPOINT_PREFIX}.{region}.amazonaws.com`.
const ENDPOINT_PREFIX: &str = "bedrock-runtime";
/// Service name used in the SigV4 credential scope and signing-key derivation.
const SIGNING_SERVICE: &str = "bedrock";
/// Region used when no region can be resolved from the environment or AWS config.
const DEFAULT_REGION: &str = "us-east-1";
27
/// Authentication material attached to a Bedrock request.
enum BedrockAuth {
    /// AWS credentials used to compute an `AWS4-HMAC-SHA256` authorization header.
    SigV4(AwsCredentials),
    /// Opaque bearer token (e.g. taken from the `BEDROCK_API_KEY` environment variable).
    BearerToken(String),
}
35
/// A resolved set of AWS credentials together with the region they are used in.
#[derive(Clone)]
struct AwsCredentials {
    /// AWS access key id.
    access_key_id: String,
    /// AWS secret access key; input to the SigV4 signing-key derivation.
    secret_access_key: String,
    /// Session token for temporary credentials, when one was provided.
    session_token: Option<String>,
    /// AWS region used for the endpoint host and the credential scope.
    region: String,
    /// Expiry of temporary credentials; `None` is treated as "never expires"
    /// by `is_expired`.
    expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
49
50impl AwsCredentials {
51 fn from_env() -> anyhow::Result<Self> {
53 let access_key_id = env_required("AWS_ACCESS_KEY_ID")?;
54 let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?;
55
56 let session_token = env_optional("AWS_SESSION_TOKEN");
57
58 let region = env_optional("AWS_REGION")
59 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
60 .unwrap_or_else(|| DEFAULT_REGION.to_string());
61
62 Ok(Self {
63 access_key_id,
64 secret_access_key,
65 session_token,
66 region,
67 expires_at: None,
68 })
69 }
70
71 fn parse_aws_config(content: &str, profile: &str) -> Option<(String, Option<String>)> {
74 let target = if profile == "default" {
75 "[default]".to_string()
76 } else {
77 format!("[profile {profile}]")
78 };
79
80 let mut in_section = false;
81 let mut cred_process = None;
82 let mut region = None;
83
84 for line in content.lines() {
85 let trimmed = line.trim();
86 if trimmed.starts_with('[') {
87 in_section = trimmed == target;
88 continue;
89 }
90 if !in_section || trimmed.starts_with('#') || trimmed.starts_with(';') {
91 continue;
92 }
93 if let Some((key, value)) = trimmed.split_once('=') {
94 match key.trim() {
95 "credential_process" => cred_process = Some(value.trim().to_string()),
96 "region" => region = Some(value.trim().to_string()),
97 _ => {}
98 }
99 }
100 }
101 cred_process.map(|cmd| (cmd, region))
102 }
103
    /// Loads credentials by running the `credential_process` command configured
    /// for the active AWS profile.
    ///
    /// Steps:
    /// 1. Read the config file named by `AWS_CONFIG_FILE`, or `$HOME/.aws/config`.
    /// 2. Look up `credential_process` (and `region`) for `AWS_PROFILE`
    ///    (default: `default`).
    /// 3. Run the command through `sh -c` and parse its stdout as JSON.
    /// 4. Read `AccessKeyId`, `SecretAccessKey`, optional `SessionToken` and
    ///    optional RFC 3339 `Expiration` from that JSON.
    ///
    /// The region is resolved from `AWS_REGION`, `AWS_DEFAULT_REGION`, the
    /// profile's `region`, then `DEFAULT_REGION`.
    ///
    /// This call is synchronous: it blocks until the child process exits.
    ///
    /// # Errors
    /// Fails when the config file cannot be read, the profile has no
    /// `credential_process`, the command cannot be spawned or exits non-zero,
    /// its output is not valid JSON, or a mandatory key is missing. Most
    /// failure paths also emit an ERROR log record.
    fn from_credential_process() -> anyhow::Result<Self> {
        let config_path = std::env::var("AWS_CONFIG_FILE").unwrap_or_else(|_| {
            // NOTE(review): when HOME is unset the literal "~" is used; `std::fs`
            // does not expand it, so this fallback path probably never resolves.
            let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
            format!("{home}/.aws/config")
        });
        let content = std::fs::read_to_string(&config_path).map_err(|e| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({
                        "config_path": &config_path,
                        "error": format!("{}", e),
                    })),
                "bedrock: cannot read AWS config file"
            );
            anyhow::Error::msg(format!("Cannot read {config_path}: {e}"))
        })?;
        let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
        let (cmd, config_region) = Self::parse_aws_config(&content, &profile).ok_or_else(|| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({"profile": &profile})),
                "bedrock: no credential_process in AWS profile"
            );
            anyhow::Error::msg(format!("No credential_process in [{profile}]"))
        })?;

        // The configured command line is handed to the shell verbatim.
        let output = std::process::Command::new("sh")
            .args(["-c", &cmd])
            .output()
            .map_err(|e| {
                ::zeroclaw_log::record!(
                    ERROR,
                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
                    "bedrock: failed to spawn credential_process"
                );
                anyhow::Error::msg(format!("Failed to run credential_process: {e}"))
            })?;
        // A non-zero exit status is reported together with the command's stderr.
        anyhow::ensure!(
            output.status.success(),
            "credential_process exited with {}: {}",
            output.status,
            String::from_utf8_lossy(&output.stderr).trim()
        );

        let json: serde_json::Value = serde_json::from_slice(&output.stdout).map_err(|e| {
            ::zeroclaw_log::record!(
                ERROR,
                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                    .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
                "bedrock: credential_process output is not valid JSON"
            );
            anyhow::Error::msg(format!("credential_process output is not valid JSON: {e}"))
        })?;

        let access_key_id = json["AccessKeyId"]
            .as_str()
            .ok_or_else(|| {
                ::zeroclaw_log::record!(
                    ERROR,
                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                        .with_attrs(::serde_json::json!({"missing": "AccessKeyId"})),
                    "bedrock: credential_process missing AccessKeyId"
                );
                anyhow::Error::msg("Missing AccessKeyId in credential_process output")
            })?
            .to_string();
        let secret_access_key = json["SecretAccessKey"]
            .as_str()
            .ok_or_else(|| {
                ::zeroclaw_log::record!(
                    ERROR,
                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
                        .with_attrs(::serde_json::json!({"missing": "SecretAccessKey"})),
                    "bedrock: credential_process missing SecretAccessKey"
                );
                anyhow::Error::msg("Missing SecretAccessKey in credential_process output")
            })?
            .to_string();
        let session_token = json["SessionToken"].as_str().map(|s| s.to_string());

        // A missing or unparsable `Expiration` silently becomes `None`.
        let expires_at = json["Expiration"]
            .as_str()
            .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
            .map(|dt| dt.with_timezone(&chrono::Utc));

        // Environment variables take precedence over the profile's region.
        let region = env_optional("AWS_REGION")
            .or_else(|| env_optional("AWS_DEFAULT_REGION"))
            .or(config_region)
            .unwrap_or_else(|| DEFAULT_REGION.to_string());

        ::zeroclaw_log::record!(
            DEBUG,
            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
            "Loaded AWS credentials via credential_process"
        );

        Ok(Self {
            access_key_id,
            secret_access_key,
            session_token,
            region,
            expires_at,
        })
    }
218
219 async fn from_imds() -> anyhow::Result<Self> {
221 let client = reqwest::Client::builder()
222 .timeout(std::time::Duration::from_secs(3))
223 .build()?;
224
225 let token = client
227 .put("http://169.254.169.254/latest/api/token")
228 .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
229 .send()
230 .await?
231 .text()
232 .await?;
233
234 let role = client
236 .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
237 .header("X-aws-ec2-metadata-token", &token)
238 .send()
239 .await?
240 .text()
241 .await?;
242 let role = role.trim().to_string();
243 anyhow::ensure!(!role.is_empty(), "No IAM role attached to this instance");
244
245 let creds_url = format!(
247 "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
248 role
249 );
250 let creds_json: serde_json::Value = client
251 .get(&creds_url)
252 .header("X-aws-ec2-metadata-token", &token)
253 .send()
254 .await?
255 .json()
256 .await?;
257
258 let access_key_id = creds_json["AccessKeyId"]
259 .as_str()
260 .ok_or_else(|| {
261 ::zeroclaw_log::record!(
262 ERROR,
263 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
264 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
265 .with_attrs(::serde_json::json!({
266 "source": "imds",
267 "missing": "AccessKeyId",
268 })),
269 "bedrock: IMDS response missing AccessKeyId"
270 );
271 anyhow::Error::msg("Missing AccessKeyId in IMDS response")
272 })?
273 .to_string();
274 let secret_access_key = creds_json["SecretAccessKey"]
275 .as_str()
276 .ok_or_else(|| {
277 ::zeroclaw_log::record!(
278 ERROR,
279 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
280 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
281 .with_attrs(::serde_json::json!({
282 "source": "imds",
283 "missing": "SecretAccessKey",
284 })),
285 "bedrock: IMDS response missing SecretAccessKey"
286 );
287 anyhow::Error::msg("Missing SecretAccessKey in IMDS response")
288 })?
289 .to_string();
290 let session_token = creds_json["Token"].as_str().map(|s| s.to_string());
291
292 let region = match client
294 .get("http://169.254.169.254/latest/meta-data/placement/region")
295 .header("X-aws-ec2-metadata-token", &token)
296 .send()
297 .await
298 {
299 Ok(resp) => resp.text().await.unwrap_or_default(),
300 Err(_) => String::new(),
301 };
302 let region = if region.trim().is_empty() {
303 env_optional("AWS_REGION")
304 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
305 .unwrap_or_else(|| DEFAULT_REGION.to_string())
306 } else {
307 region.trim().to_string()
308 };
309
310 ::zeroclaw_log::record!(
311 INFO,
312 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
313 &format!(
314 "Loaded AWS credentials from EC2 instance metadata (role: {})",
315 role
316 )
317 );
318
319 Ok(Self {
320 access_key_id,
321 secret_access_key,
322 session_token,
323 region,
324 expires_at: None,
325 })
326 }
327
328 async fn resolve() -> anyhow::Result<Self> {
330 if let Ok(creds) = Self::from_env() {
331 return Ok(creds);
332 }
333 if let Ok(creds) = Self::from_credential_process() {
334 return Ok(creds);
335 }
336 Self::from_imds().await
337 }
338
339 fn host(&self) -> String {
340 format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region)
341 }
342
343 fn is_expired(&self) -> bool {
346 match self.expires_at {
347 Some(exp) => chrono::Utc::now() >= exp - chrono::Duration::seconds(60),
348 None => false,
349 }
350 }
351}
352
353fn env_required(name: &str) -> anyhow::Result<String> {
354 std::env::var(name)
355 .ok()
356 .map(|v| v.trim().to_string())
357 .filter(|v| !v.is_empty())
358 .ok_or_else(|| {
359 ::zeroclaw_log::record!(
360 ERROR,
361 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
362 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
363 .with_attrs(::serde_json::json!({"env_var": name})),
364 "bedrock: required environment variable is missing"
365 );
366 anyhow::Error::msg(format!(
367 "Environment variable {name} is required for Bedrock"
368 ))
369 })
370}
371
/// Reads the environment variable `name`, trimmed of surrounding whitespace.
///
/// Returns `None` when the variable is unset, not valid Unicode, or blank
/// after trimming.
fn env_optional(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let value = raw.trim();
    if value.is_empty() {
        None
    } else {
        Some(value.to_string())
    }
}
378
379fn sha256_hex(data: &[u8]) -> String {
382 let mut hasher = Sha256::new();
383 hasher.update(data);
384 hex::encode(hasher.finalize())
385}
386
387fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
388 let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
389 mac.update(data);
390 mac.finalize().into_bytes().to_vec()
391}
392
393fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
395 let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
396 let k_region = hmac_sha256(&k_date, region.as_bytes());
397 let k_service = hmac_sha256(&k_region, service.as_bytes());
398 hmac_sha256(&k_service, b"aws4_request")
399}
400
401fn build_authorization_header(
405 credentials: &AwsCredentials,
406 method: &str,
407 canonical_uri: &str,
408 query_string: &str,
409 headers: &[(String, String)],
410 payload: &[u8],
411 timestamp: &chrono::DateTime<chrono::Utc>,
412) -> String {
413 let date_stamp = timestamp.format("%Y%m%d").to_string();
414 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
415
416 let mut canonical_headers = String::new();
417 for (k, v) in headers {
418 canonical_headers.push_str(k);
419 canonical_headers.push(':');
420 canonical_headers.push_str(v);
421 canonical_headers.push('\n');
422 }
423
424 let signed_headers: String = headers
425 .iter()
426 .map(|(k, _)| k.as_str())
427 .collect::<Vec<_>>()
428 .join(";");
429
430 let payload_hash = sha256_hex(payload);
431
432 let canonical_request = format!(
433 "{method}\n{canonical_uri}\n{query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
434 );
435
436 let credential_scope = format!(
437 "{date_stamp}/{}/{SIGNING_SERVICE}/aws4_request",
438 credentials.region
439 );
440
441 let string_to_sign = format!(
442 "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
443 sha256_hex(canonical_request.as_bytes())
444 );
445
446 let signing_key = derive_signing_key(
447 &credentials.secret_access_key,
448 &date_stamp,
449 &credentials.region,
450 SIGNING_SERVICE,
451 );
452
453 let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()));
454
455 format!(
456 "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
457 credentials.access_key_id
458 )
459}
460
/// Serializable request payload for the `/converse` endpoint.
///
/// Field names are emitted in camelCase; every `Option` field is omitted from
/// the output when it is `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ConverseRequest {
    /// Conversation turns (user / assistant messages).
    messages: Vec<ConverseMessage>,
    /// System prompt blocks, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<Vec<SystemBlock>>,
    /// Token limit and temperature settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    inference_config: Option<InferenceConfig>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<ToolConfig>,
    /// Free-form JSON serialized as `additionalModelRequestFields`.
    #[serde(skip_serializing_if = "Option::is_none")]
    additional_model_request_fields: Option<serde_json::Value>,
}
476
/// One conversation turn: a role plus an ordered list of content blocks.
#[derive(Debug, Serialize, Deserialize)]
struct ConverseMessage {
    /// Speaker role; this file produces `"user"` and `"assistant"`.
    role: String,
    /// Content blocks making up the turn.
    content: Vec<ContentBlock>,
}
482
/// A single content block inside a [`ConverseMessage`].
///
/// The enum is `untagged`: each variant serializes as its inner wrapper struct
/// with no surrounding variant tag, so the JSON key identifying the block kind
/// comes from the wrapper's own field name.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ContentBlock {
    /// Plain text.
    Text(TextBlock),
    /// A tool invocation requested by the assistant.
    ToolUse(ToolUseWrapper),
    /// The result of a tool invocation.
    ToolResult(ToolResultWrapper),
    /// A prompt-cache checkpoint marker.
    CachePointBlock(CachePointWrapper),
    /// An inline image.
    Image(ImageWrapper),
    /// Reasoning text (with optional signature) echoed back to the model.
    // NOTE(review): with `untagged` the variant-level rename probably has no
    // effect on the wire format; the `reasoningContent` key is produced by the
    // wrapper's camelCase field — confirm before relying on this attribute.
    #[serde(rename = "reasoningContent")]
    ReasoningContent(ReasoningContentOutWrapper),
}
503
/// Outgoing reasoning block wrapper; serializes as `{"reasoningContent": …}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentOutWrapper {
    /// The wrapped reasoning payload.
    reasoning_content: ReasoningContentOutBlock,
}
511
/// Outgoing reasoning payload; serializes as `{"reasoningText": …}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentOutBlock {
    /// Reasoning text together with its optional signature.
    reasoning_text: ReasoningTextOutField,
}
517
/// Reasoning text and optional signature sent back to the model.
#[derive(Debug, Serialize, Deserialize)]
struct ReasoningTextOutField {
    /// The reasoning text.
    text: String,
    /// Signature accompanying the reasoning text; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    signature: Option<String>,
}
526
/// Image content block wrapper; serializes as `{"image": …}`.
#[derive(Debug, Serialize, Deserialize)]
struct ImageWrapper {
    /// The wrapped image.
    image: ImageBlock,
}
531
/// An inline image: its format name and its data source.
#[derive(Debug, Serialize, Deserialize)]
struct ImageBlock {
    /// Image format; this file emits `png`, `gif`, `webp` or `jpeg`.
    format: String,
    /// Image data.
    source: ImageSource,
}
537
/// Image data carried inline in the request.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ImageSource {
    /// Image payload as a string; this file fills it with the base64 part of a
    /// `data:` URI.
    bytes: String,
}
543
/// Plain text block; serializes as `{"text": …}`.
#[derive(Debug, Serialize, Deserialize)]
struct TextBlock {
    /// The text content.
    text: String,
}
548
/// Tool-use content block wrapper; serializes as `{"toolUse": …}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseWrapper {
    /// The wrapped tool invocation.
    tool_use: ToolUseBlock,
}
554
/// A tool invocation: id, tool name and JSON arguments (camelCase on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseBlock {
    /// Identifier that the matching tool result refers back to.
    tool_use_id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Tool arguments as arbitrary JSON.
    input: serde_json::Value,
}
562
/// Tool-result content block wrapper; serializes as `{"toolResult": …}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultWrapper {
    /// The wrapped tool result.
    tool_result: ToolResultBlock,
}
568
/// Result of a tool invocation (camelCase on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultBlock {
    /// Id of the tool invocation this result answers.
    tool_use_id: String,
    /// Result payload as text items.
    content: Vec<ToolResultContent>,
    /// Outcome marker; this file writes `"success"` or `"error"`.
    status: String,
}
576
/// Cache-point block wrapper; serializes as `{"cachePoint": …}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CachePointWrapper {
    /// The wrapped cache-point marker.
    cache_point: CachePoint,
}
582
/// One text item inside a tool result; serializes as `{"text": …}`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolResultContent {
    /// Tool output text.
    text: String,
}
587
/// Cache-point marker; serializes as `{"type": …}`.
#[derive(Debug, Serialize, Deserialize)]
struct CachePoint {
    /// Cache kind, emitted under the JSON key `type`.
    #[serde(rename = "type")]
    cache_type: String,
}
593
594impl CachePoint {
595 fn default_cache() -> Self {
596 Self {
597 cache_type: "default".to_string(),
598 }
599 }
600}
601
/// One entry of the request's `system` array.
///
/// `untagged`: each variant serializes as its inner struct with no variant tag.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum SystemBlock {
    /// System prompt text.
    Text(TextBlock),
    /// Prompt-cache checkpoint marker.
    CachePoint(CachePointWrapper),
}
609
/// Inference settings; serialized with camelCase keys (`maxTokens`, `temperature`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InferenceConfig {
    /// Maximum number of tokens to generate.
    max_tokens: u32,
    /// Sampling temperature; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
}
617
/// Returns `true` when the model id contains `claude-opus-4-7`, i.e. for models
/// where the temperature setting is to be left out.
fn bedrock_model_omits_temperature(model: &str) -> bool {
    const NO_TEMPERATURE_MARKER: &str = "claude-opus-4-7";
    model.find(NO_TEMPERATURE_MARKER).is_some()
}
625
/// Returns `true` for every model id that does not contain `claude-opus-4-7`.
fn bedrock_model_supports_native_thinking(model: &str) -> bool {
    const NO_NATIVE_THINKING_MARKER: &str = "claude-opus-4-7";
    model.find(NO_NATIVE_THINKING_MARKER).is_none()
}
636
/// Tool configuration section of the request; serializes as `{"tools": [...]}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolConfig {
    /// Tools offered to the model.
    tools: Vec<ToolDefinition>,
}
642
/// One tool entry; serializes as `{"toolSpec": …}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolDefinition {
    /// The tool's specification.
    tool_spec: ToolSpecDef,
}
648
/// Specification of a single tool (camelCase on the wire).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolSpecDef {
    /// Tool name.
    name: String,
    /// Tool description.
    description: String,
    /// Schema of the tool's input, emitted as `inputSchema`.
    input_schema: InputSchema,
}
656
/// Wrapper placing a tool's parameter schema under the `json` key.
#[derive(Debug, Serialize)]
struct InputSchema {
    /// The schema as arbitrary JSON.
    json: serde_json::Value,
}
661
/// Deserialized response from the `/converse` endpoint.
///
/// Keys are read in camelCase; every field falls back to `None` when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    /// Model output, if present.
    #[serde(default)]
    output: Option<ConverseOutput>,
    /// Value of `stopReason`; deserialized but unused in this part of the file.
    #[serde(default)]
    #[allow(dead_code)]
    stop_reason: Option<String>,
    /// Token accounting, if present.
    #[serde(default)]
    usage: Option<BedrockUsage>,
}
675
/// Token usage reported in a response (`inputTokens` / `outputTokens`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockUsage {
    /// Input token count; `None` when absent.
    #[serde(default)]
    input_tokens: Option<u64>,
    /// Output token count; `None` when absent.
    #[serde(default)]
    output_tokens: Option<u64>,
}
684
/// The `output` object of a response.
#[derive(Debug, Deserialize)]
struct ConverseOutput {
    /// The assistant message; `None` when absent.
    #[serde(default)]
    message: Option<ConverseOutputMessage>,
}
690
/// The message returned inside a response's `output`.
#[derive(Debug, Deserialize)]
struct ConverseOutputMessage {
    /// Message role; required when deserializing but not read in this part of the file.
    #[allow(dead_code)]
    role: String,
    /// Content blocks produced by the model.
    content: Vec<ResponseContentBlock>,
}
697
/// A content block in a response message.
///
/// `untagged`: serde tries the variants in declaration order and keeps the
/// first that deserializes, so the order below is significant. `Other` is last
/// and accepts any JSON value, so unrecognised block kinds do not fail the
/// whole response.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ResponseContentBlock {
    /// A tool invocation requested by the model.
    ToolUse(ResponseToolUseWrapper),
    /// A reasoning block.
    ReasoningContent(ReasoningContentWrapper),
    /// Plain text.
    Text(TextBlock),
    /// Any block that matched none of the variants above; its content is ignored.
    Other(#[allow(dead_code)] serde_json::Value),
}
712
/// Incoming reasoning block wrapper; matches JSON with a `reasoningContent` key.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentWrapper {
    /// The wrapped reasoning payload (required).
    reasoning_content: ReasoningContentBlock,
}
718
/// Incoming reasoning payload.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentBlock {
    /// Value of `reasoningText`; `None` when absent.
    #[serde(default)]
    reasoning_text: Option<ReasoningTextField>,
}
725
/// Incoming reasoning text and signature; both parts are optional.
#[derive(Debug, Deserialize)]
struct ReasoningTextField {
    /// Reasoning text; `None` when absent.
    #[serde(default)]
    text: Option<String>,
    /// Signature accompanying the text; `None` when absent.
    #[serde(default)]
    signature: Option<String>,
}
735
/// Incoming tool-use block wrapper; matches JSON with a `toolUse` key.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ResponseToolUseWrapper {
    /// The wrapped tool invocation.
    tool_use: ToolUseBlock,
}
741
/// Model provider backed by AWS Bedrock.
pub struct BedrockModelProvider {
    /// Alias this provider instance was created with.
    alias: String,
    /// Authentication resolved at construction time; `None` when nothing could
    /// be resolved then (resolution is retried in `resolve_auth`).
    auth: Option<BedrockAuth>,
    /// Maximum number of tokens to request; see `with_max_tokens`.
    max_tokens: u32,
    /// Cache of the most recently stored AWS credentials, guarded by a mutex.
    cred_cache: Mutex<Option<AwsCredentials>>,
}
752
753impl BedrockModelProvider {
754 pub fn new(alias: &str) -> Self {
755 if let Some(token) = env_optional("BEDROCK_API_KEY") {
757 return Self {
758 alias: alias.to_string(),
759 auth: Some(BedrockAuth::BearerToken(token)),
760 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
761 cred_cache: Mutex::new(None),
762 };
763 }
764 Self {
765 alias: alias.to_string(),
766 auth: AwsCredentials::from_env()
767 .or_else(|_| AwsCredentials::from_credential_process())
768 .ok()
769 .map(BedrockAuth::SigV4),
770 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
771 cred_cache: Mutex::new(None),
772 }
773 }
774
775 pub async fn new_async(alias: &str) -> Self {
776 if let Some(token) = env_optional("BEDROCK_API_KEY") {
778 return Self {
779 alias: alias.to_string(),
780 auth: Some(BedrockAuth::BearerToken(token)),
781 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
782 cred_cache: Mutex::new(None),
783 };
784 }
785 let auth = AwsCredentials::resolve().await.ok().map(BedrockAuth::SigV4);
786 Self {
787 alias: alias.to_string(),
788 auth,
789 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
790 cred_cache: Mutex::new(None),
791 }
792 }
793
794 pub fn with_bearer_token(alias: &str, token: &str) -> Self {
796 Self {
797 alias: alias.to_string(),
798 auth: Some(BedrockAuth::BearerToken(token.to_string())),
799 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
800 cred_cache: Mutex::new(None),
801 }
802 }
803 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
805 self.max_tokens = max_tokens;
806 self
807 }
808
    /// Builds the HTTP client used for Bedrock requests via the shared
    /// proxy-aware client factory, registered under the key
    /// `model_provider.bedrock`.
    fn http_client(&self) -> Client {
        zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
            "model_provider.bedrock",
            // NOTE(review): presumably the overall request timeout and the connect
            // timeout, in seconds — confirm against the helper's signature.
            120,
            10,
        )
    }
816
817 fn encode_model_path(model_id: &str) -> String {
821 model_id.replace(':', "%3A")
822 }
823
824 fn resolve_region() -> String {
826 env_optional("AWS_REGION")
827 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
828 .unwrap_or_else(|| DEFAULT_REGION.to_string())
829 }
830
831 fn endpoint_url(region: &str, model_id: &str) -> String {
833 format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
834 }
835
836 fn canonical_uri(model_id: &str) -> String {
840 let encoded = Self::encode_model_path(model_id);
841 format!("/model/{encoded}/converse")
842 }
843
844 fn cached_credentials(&self) -> Option<AwsCredentials> {
846 let cache = self.cred_cache.lock().ok()?;
847 let creds = cache.as_ref()?;
848 if creds.is_expired() {
849 return None;
850 }
851 Some(creds.clone())
852 }
853
854 fn cache_credentials(&self, creds: &AwsCredentials) {
856 if let Ok(mut cache) = self.cred_cache.lock() {
857 *cache = Some(creds.clone());
858 }
859 }
860
861 async fn resolve_auth(&self) -> anyhow::Result<BedrockAuth> {
863 if let Some(ref auth) = self.auth {
865 match auth {
866 BedrockAuth::BearerToken(token) => {
867 return Ok(BedrockAuth::BearerToken(token.clone()));
868 }
869 BedrockAuth::SigV4(_) => {
870 if let Some(creds) = self.cached_credentials() {
871 return Ok(BedrockAuth::SigV4(creds));
872 }
873 }
874 }
875 }
876 if let Some(token) = env_optional("BEDROCK_API_KEY") {
878 return Ok(BedrockAuth::BearerToken(token));
879 }
880 if let Ok(creds) = AwsCredentials::from_env() {
882 return Ok(BedrockAuth::SigV4(creds));
883 }
884 if let Ok(creds) = AwsCredentials::from_credential_process() {
885 self.cache_credentials(&creds);
886 return Ok(BedrockAuth::SigV4(creds));
887 }
888 Ok(BedrockAuth::SigV4(AwsCredentials::from_imds().await?))
889 }
890
891 fn should_cache_system(text: &str) -> bool {
895 text.len() > 3072
896 }
897
898 fn should_cache_conversation(messages: &[ChatMessage]) -> bool {
900 messages.iter().filter(|m| m.role != "system").count() > 4
901 }
902
    /// Converts provider-agnostic chat messages into system blocks and
    /// Converse messages.
    ///
    /// * `system` – only the first system message is kept; later ones are ignored.
    /// * `assistant` – parsed as a tool-call message when possible, otherwise
    ///   sent as text (blank content becomes `(empty response)`).
    /// * `tool` – turned into a `toolResult` block inside a `user` message.
    ///   Consecutive tool results are merged into a single `user` message.
    /// * any other role – treated as a user message and parsed for inline images.
    ///
    /// Returns `(system, messages)`, where `system` is `None` when no system
    /// message was seen.
    fn convert_messages(
        messages: &[ChatMessage],
    ) -> (Option<Vec<SystemBlock>>, Vec<ConverseMessage>) {
        let mut system_blocks = Vec::new();
        let mut converse_messages = Vec::new();

        for msg in messages {
            match msg.role.as_str() {
                "system" => {
                    // Keep the first system message only.
                    if system_blocks.is_empty() {
                        system_blocks.push(SystemBlock::Text(TextBlock {
                            text: msg.content.clone(),
                        }));
                    }
                }
                "assistant" => {
                    if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
                        converse_messages.push(ConverseMessage {
                            role: "assistant".to_string(),
                            content: blocks,
                        });
                    } else {
                        // Substitute a placeholder for blank assistant text.
                        let text = if msg.content.trim().is_empty() {
                            "(empty response)".to_string()
                        } else {
                            msg.content.clone()
                        };
                        converse_messages.push(ConverseMessage {
                            role: "assistant".to_string(),
                            content: vec![ContentBlock::Text(TextBlock { text })],
                        });
                    }
                }
                "tool" => {
                    let tool_result_msg = Self::parse_tool_result_message(&msg.content)
                        .unwrap_or_else(|| {
                            // Unparsable tool message: recover an id from the raw
                            // content, else from the last unanswered tool use,
                            // else fall back to "unknown".
                            let tool_use_id = Self::extract_tool_call_id(&msg.content)
                                .or_else(|| Self::last_pending_tool_use_id(&converse_messages))
                                .unwrap_or_else(|| "unknown".to_string());

                            ::zeroclaw_log::record!(
                                WARN,
                                ::zeroclaw_log::Event::new(
                                    module_path!(),
                                    ::zeroclaw_log::Action::Note
                                )
                                .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
                                &format!(
                                    "Failed to parse tool result message, creating error \
                                     toolResult for tool_use_id={}",
                                    tool_use_id
                                )
                            );

                            // The raw content is forwarded as an error result.
                            ConverseMessage {
                                role: "user".to_string(),
                                content: vec![ContentBlock::ToolResult(ToolResultWrapper {
                                    tool_result: ToolResultBlock {
                                        tool_use_id,
                                        content: vec![ToolResultContent {
                                            text: msg.content.clone(),
                                        }],
                                        status: "error".to_string(),
                                    },
                                })],
                            }
                        });

                    // Merge into the previous message when it is a user message
                    // consisting solely of tool results.
                    if let Some(last) = converse_messages.last_mut()
                        && last.role == "user"
                        && last
                            .content
                            .iter()
                            .all(|b| matches!(b, ContentBlock::ToolResult(_)))
                    {
                        last.content.extend(tool_result_msg.content);
                        continue;
                    }
                    converse_messages.push(tool_result_msg);
                }
                _ => {
                    let content_blocks = Self::parse_user_content_blocks(&msg.content);
                    converse_messages.push(ConverseMessage {
                        role: "user".to_string(),
                        content: content_blocks,
                    });
                }
            }
        }

        let system = if system_blocks.is_empty() {
            None
        } else {
            Some(system_blocks)
        };
        (system, converse_messages)
    }
1012
1013 fn sanitize_empty_content_blocks(messages: &mut [ConverseMessage]) {
1021 for msg in messages.iter_mut() {
1022 msg.content.retain(|block| match block {
1023 ContentBlock::Text(tb) => !tb.text.trim().is_empty(),
1024 _ => true,
1025 });
1026 if msg.content.is_empty() {
1027 msg.content.push(ContentBlock::Text(TextBlock {
1028 text: "(empty)".to_string(),
1029 }));
1030 }
1031 }
1032 }
1033
1034 fn extract_tool_call_id(content: &str) -> Option<String> {
1036 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
1037 value
1038 .get("tool_call_id")
1039 .or_else(|| value.get("tool_use_id"))
1040 .or_else(|| value.get("toolUseId"))
1041 .and_then(serde_json::Value::as_str)
1042 .map(String::from)
1043 }
1044
1045 fn last_pending_tool_use_id(converse_messages: &[ConverseMessage]) -> Option<String> {
1051 let last_assistant = converse_messages
1052 .iter()
1053 .rev()
1054 .find(|m| m.role == "assistant")?;
1055
1056 let tool_use_ids: Vec<&str> = last_assistant
1057 .content
1058 .iter()
1059 .filter_map(|b| match b {
1060 ContentBlock::ToolUse(wrapper) => Some(wrapper.tool_use.tool_use_id.as_str()),
1061 _ => None,
1062 })
1063 .collect();
1064
1065 let answered_ids: Vec<&str> = converse_messages
1066 .iter()
1067 .rev()
1068 .take_while(|m| m.role == "user")
1069 .flat_map(|m| m.content.iter())
1070 .filter_map(|b| match b {
1071 ContentBlock::ToolResult(wrapper) => Some(wrapper.tool_result.tool_use_id.as_str()),
1072 _ => None,
1073 })
1074 .collect();
1075
1076 tool_use_ids
1077 .into_iter()
1078 .find(|id| !answered_ids.contains(id))
1079 .map(String::from)
1080 }
1081
1082 fn parse_user_content_blocks(content: &str) -> Vec<ContentBlock> {
1084 let mut blocks: Vec<ContentBlock> = Vec::new();
1085 let mut remaining = content;
1086 let has_image = content.contains("[IMAGE:");
1087 ::zeroclaw_log::record!(
1088 INFO,
1089 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
1090 &format!(
1091 "parse_user_content_blocks called, len={}, has_image={}",
1092 content.len(),
1093 has_image
1094 )
1095 );
1096
1097 while let Some(start) = remaining.find("[IMAGE:") {
1098 let text_before = &remaining[..start];
1100 if !text_before.trim().is_empty() {
1101 blocks.push(ContentBlock::Text(TextBlock {
1102 text: text_before.to_string(),
1103 }));
1104 }
1105
1106 let after = &remaining[start + 7..]; if let Some(end) = after.find(']') {
1108 let src = &after[..end];
1109 remaining = &after[end + 1..];
1110
1111 if let Some(rest) = src.strip_prefix("data:")
1113 && let Some(semi) = rest.find(';')
1114 {
1115 let mime = &rest[..semi];
1116 let after_semi = &rest[semi + 1..];
1117 if let Some(b64) = after_semi.strip_prefix("base64,") {
1118 let format = match mime {
1119 "image/png" => "png",
1120 "image/gif" => "gif",
1121 "image/webp" => "webp",
1122 _ => "jpeg",
1123 };
1124 blocks.push(ContentBlock::Image(ImageWrapper {
1125 image: ImageBlock {
1126 format: format.to_string(),
1127 source: ImageSource {
1128 bytes: b64.to_string(),
1129 },
1130 },
1131 }));
1132 continue;
1133 }
1134 }
1135 blocks.push(ContentBlock::Text(TextBlock {
1137 text: format!("[image: {}]", src),
1138 }));
1139 } else {
1140 blocks.push(ContentBlock::Text(TextBlock {
1142 text: remaining.to_string(),
1143 }));
1144 break;
1145 }
1146 }
1147
1148 if !remaining.trim().is_empty() {
1150 blocks.push(ContentBlock::Text(TextBlock {
1151 text: remaining.to_string(),
1152 }));
1153 }
1154
1155 if blocks.is_empty() {
1156 let fallback = if content.trim().is_empty() {
1157 "(empty)".to_string()
1158 } else {
1159 content.to_string()
1160 };
1161 blocks.push(ContentBlock::Text(TextBlock { text: fallback }));
1162 }
1163
1164 blocks
1165 }
1166
1167 fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<ContentBlock>> {
1169 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
1170 let tool_calls = value
1171 .get("tool_calls")
1172 .and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
1173
1174 let mut blocks = Vec::new();
1175
1176 if let Some(reasoning) = value
1181 .get("reasoning_content")
1182 .and_then(serde_json::Value::as_str)
1183 .filter(|r| !r.is_empty())
1184 {
1185 for part in reasoning.split('\n') {
1187 if let Ok(block) = serde_json::from_str::<serde_json::Value>(part) {
1188 let text = block
1189 .get("text")
1190 .and_then(|t| t.as_str())
1191 .unwrap_or("")
1192 .to_string();
1193 let signature = block
1194 .get("signature")
1195 .and_then(|s| s.as_str())
1196 .filter(|s| !s.is_empty())
1197 .map(|s| s.to_string());
1198 blocks.push(ContentBlock::ReasoningContent(ReasoningContentOutWrapper {
1199 reasoning_content: ReasoningContentOutBlock {
1200 reasoning_text: ReasoningTextOutField { text, signature },
1201 },
1202 }));
1203 }
1204 }
1205 }
1206
1207 if let Some(text) = value
1208 .get("content")
1209 .and_then(serde_json::Value::as_str)
1210 .map(str::trim)
1211 .filter(|t| !t.is_empty())
1212 {
1213 blocks.push(ContentBlock::Text(TextBlock {
1214 text: text.to_string(),
1215 }));
1216 }
1217 for call in tool_calls {
1218 let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
1219 .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
1220 blocks.push(ContentBlock::ToolUse(ToolUseWrapper {
1221 tool_use: ToolUseBlock {
1222 tool_use_id: call.id,
1223 name: call.name,
1224 input,
1225 },
1226 }));
1227 }
1228 Some(blocks)
1229 }
1230
1231 fn parse_tool_result_message(content: &str) -> Option<ConverseMessage> {
1233 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
1234 let tool_use_id = value
1235 .get("tool_call_id")
1236 .or_else(|| value.get("tool_use_id"))
1237 .or_else(|| value.get("toolUseId"))
1238 .and_then(serde_json::Value::as_str)?
1239 .to_string();
1240 let result = value
1241 .get("content")
1242 .and_then(serde_json::Value::as_str)
1243 .unwrap_or("")
1244 .to_string();
1245 Some(ConverseMessage {
1246 role: "user".to_string(),
1247 content: vec![ContentBlock::ToolResult(ToolResultWrapper {
1248 tool_result: ToolResultBlock {
1249 tool_use_id,
1250 content: vec![ToolResultContent { text: result }],
1251 status: "success".to_string(),
1252 },
1253 })],
1254 })
1255 }
1256
1257 fn convert_tools_to_converse(tools: Option<&[ToolSpec]>) -> Option<ToolConfig> {
1260 let items = tools?;
1261 if items.is_empty() {
1262 return None;
1263 }
1264 let tool_defs: Vec<ToolDefinition> = items
1265 .iter()
1266 .map(|tool| ToolDefinition {
1267 tool_spec: ToolSpecDef {
1268 name: tool.name.clone(),
1269 description: tool.description.clone(),
1270 input_schema: InputSchema {
1271 json: tool.parameters.clone(),
1272 },
1273 },
1274 })
1275 .collect();
1276 Some(ToolConfig { tools: tool_defs })
1277 }
1278
1279 fn parse_converse_response(response: ConverseResponse) -> ProviderChatResponse {
1282 let mut text_parts = Vec::new();
1283 let mut thinking_parts = Vec::new();
1284 let mut tool_calls = Vec::new();
1285
1286 let usage = response.usage.map(|u| TokenUsage {
1287 input_tokens: u.input_tokens,
1288 output_tokens: u.output_tokens,
1289 cached_input_tokens: None,
1290 });
1291
1292 if let Some(output) = response.output
1293 && let Some(message) = output.message
1294 {
1295 for block in message.content {
1296 match block {
1297 ResponseContentBlock::Text(tb) => {
1298 let trimmed = tb.text.trim().to_string();
1299 if !trimmed.is_empty() {
1300 text_parts.push(trimmed);
1301 }
1302 }
1303 ResponseContentBlock::ReasoningContent(wrapper) => {
1304 if let Some(reasoning_text) = wrapper.reasoning_content.reasoning_text {
1305 let block = serde_json::json!({
1307 "text": reasoning_text.text.as_deref().unwrap_or(""),
1308 "signature": reasoning_text.signature.as_deref().unwrap_or(""),
1309 });
1310 thinking_parts.push(block.to_string());
1311 }
1312 }
1313 ResponseContentBlock::ToolUse(wrapper) => {
1314 if !wrapper.tool_use.name.is_empty() {
1315 tool_calls.push(ProviderToolCall {
1316 id: wrapper.tool_use.tool_use_id,
1317 name: wrapper.tool_use.name,
1318 arguments: wrapper.tool_use.input.to_string(),
1319 extra_content: None,
1320 });
1321 }
1322 }
1323 ResponseContentBlock::Other(_) => {}
1324 }
1325 }
1326 }
1327
1328 let reasoning_content = if thinking_parts.is_empty() {
1329 None
1330 } else {
1331 Some(thinking_parts.join("\n"))
1332 };
1333
1334 ProviderChatResponse {
1335 text: if text_parts.is_empty() {
1336 None
1337 } else {
1338 Some(text_parts.join("\n"))
1339 },
1340 tool_calls,
1341 usage,
1342 reasoning_content,
1343 }
1344 }
1345
1346 async fn send_converse_request(
1349 &self,
1350 auth: &BedrockAuth,
1351 model: &str,
1352 request_body: &ConverseRequest,
1353 ) -> anyhow::Result<ConverseResponse> {
1354 let payload = serde_json::to_vec(request_body)?;
1355
1356 if let Ok(debug_val) = serde_json::from_slice::<serde_json::Value>(&payload)
1358 && let Some(msgs) = debug_val.get("messages").and_then(|m| m.as_array())
1359 {
1360 for msg in msgs {
1361 if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
1362 for block in content {
1363 if block.get("image").is_some() {
1364 let mut b = block.clone();
1365 if let Some(img) = b.get_mut("image")
1366 && let Some(src) = img.get_mut("source")
1367 && let Some(bytes) = src.get_mut("bytes")
1368 && let Some(s) = bytes.as_str()
1369 {
1370 *bytes = serde_json::json!(format!("<base64 {} chars>", s.len()));
1371 }
1372 ::zeroclaw_log::record!(
1373 INFO,
1374 ::zeroclaw_log::Event::new(
1375 module_path!(),
1376 ::zeroclaw_log::Action::Note
1377 ),
1378 &format!(
1379 "Bedrock image block: {}",
1380 serde_json::to_string(&b).unwrap_or_default()
1381 )
1382 );
1383 }
1384 }
1385 }
1386 }
1387 }
1388
1389 let response: reqwest::Response = match auth {
1390 BedrockAuth::BearerToken(token) => {
1391 let region = Self::resolve_region();
1392 let url = Self::endpoint_url(®ion, model);
1393
1394 self.http_client()
1395 .post(&url)
1396 .header("content-type", "application/json")
1397 .header("Authorization", format!("Bearer {token}"))
1398 .body(payload)
1399 .send()
1400 .await?
1401 }
1402 BedrockAuth::SigV4(credentials) => {
1403 let url = Self::endpoint_url(&credentials.region, model);
1404 let canonical_uri = Self::canonical_uri(model);
1405 let now = chrono::Utc::now();
1406 let host = credentials.host();
1407 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
1408
1409 let mut headers_to_sign = vec![
1410 ("content-type".to_string(), "application/json".to_string()),
1411 ("host".to_string(), host),
1412 ("x-amz-date".to_string(), amz_date.clone()),
1413 ];
1414 if let Some(ref session_token) = credentials.session_token {
1415 headers_to_sign
1416 .push(("x-amz-security-token".to_string(), session_token.clone()));
1417 }
1418 headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
1419
1420 let authorization = build_authorization_header(
1421 credentials,
1422 "POST",
1423 &canonical_uri,
1424 "",
1425 &headers_to_sign,
1426 &payload,
1427 &now,
1428 );
1429
1430 let mut request = self
1431 .http_client()
1432 .post(&url)
1433 .header("content-type", "application/json")
1434 .header("x-amz-date", &amz_date)
1435 .header("authorization", &authorization);
1436
1437 if let Some(ref session_token) = credentials.session_token {
1438 request = request.header("x-amz-security-token", session_token);
1439 }
1440
1441 request.body(payload).send().await?
1442 }
1443 };
1444
1445 if !response.status().is_success() {
1446 return Err(super::api_error("Bedrock", response).await);
1447 }
1448
1449 let converse_response: ConverseResponse = response.json().await?;
1450 Ok(converse_response)
1451 }
1452}
1453
1454#[async_trait]
1457impl ModelProvider for BedrockModelProvider {
1458 fn capabilities(&self) -> ProviderCapabilities {
1459 ProviderCapabilities {
1460 native_tool_calling: true,
1461 vision: true,
1462 prompt_caching: false,
1463 extended_thinking: true,
1464 }
1465 }
1466
1467 fn supports_native_tools(&self) -> bool {
1468 true
1469 }
1470
1471 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
1472 let tool_values: Vec<serde_json::Value> = tools
1473 .iter()
1474 .map(|t| {
1475 serde_json::json!({
1476 "toolSpec": {
1477 "name": t.name,
1478 "description": t.description,
1479 "inputSchema": { "json": t.parameters }
1480 }
1481 })
1482 })
1483 .collect();
1484 ToolsPayload::Anthropic { tools: tool_values }
1485 }
1486
1487 async fn chat_with_system(
1488 &self,
1489 system_prompt: Option<&str>,
1490 message: &str,
1491 model: &str,
1492 temperature: Option<f64>,
1493 ) -> anyhow::Result<String> {
1494 let temperature = temperature.unwrap_or(self.default_temperature());
1495 let auth = self.resolve_auth().await?;
1496
1497 let system = system_prompt.map(|text| {
1498 let mut blocks = vec![SystemBlock::Text(TextBlock {
1499 text: text.to_string(),
1500 })];
1501 if Self::should_cache_system(text) {
1502 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1503 cache_point: CachePoint::default_cache(),
1504 }));
1505 }
1506 blocks
1507 });
1508
1509 let mut messages = vec![ConverseMessage {
1510 role: "user".to_string(),
1511 content: Self::parse_user_content_blocks(message),
1512 }];
1513 Self::sanitize_empty_content_blocks(&mut messages);
1514
1515 let request = ConverseRequest {
1516 system,
1517 messages,
1518 inference_config: Some(InferenceConfig {
1519 max_tokens: self.max_tokens,
1520 temperature: if bedrock_model_omits_temperature(model) {
1521 None
1522 } else {
1523 Some(temperature)
1524 },
1525 }),
1526 tool_config: None,
1527 additional_model_request_fields: None,
1528 };
1529
1530 let response = self.send_converse_request(&auth, model, &request).await?;
1531
1532 Self::parse_converse_response(response).text.ok_or_else(|| {
1533 ::zeroclaw_log::record!(
1534 ERROR,
1535 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1536 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
1537 "bedrock: empty text in response"
1538 );
1539 anyhow::Error::msg("No response from Bedrock")
1540 })
1541 }
1542
1543 async fn chat(
1544 &self,
1545 request: ProviderChatRequest<'_>,
1546 model: &str,
1547 temperature: Option<f64>,
1548 ) -> anyhow::Result<ProviderChatResponse> {
1549 let temperature = temperature.unwrap_or(self.default_temperature());
1550 let auth = self.resolve_auth().await?;
1551
1552 let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
1553
1554 Self::sanitize_empty_content_blocks(&mut converse_messages);
1556
1557 let system = system_blocks.map(|mut blocks| {
1559 let has_large_system = blocks
1560 .iter()
1561 .any(|b| matches!(b, SystemBlock::Text(tb) if Self::should_cache_system(&tb.text)));
1562 if has_large_system {
1563 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1564 cache_point: CachePoint::default_cache(),
1565 }));
1566 }
1567 blocks
1568 });
1569
1570 if Self::should_cache_conversation(request.messages)
1572 && let Some(last_msg) = converse_messages.last_mut()
1573 {
1574 last_msg
1575 .content
1576 .push(ContentBlock::CachePointBlock(CachePointWrapper {
1577 cache_point: CachePoint::default_cache(),
1578 }));
1579 }
1580
1581 let tool_config = Self::convert_tools_to_converse(request.tools);
1582
1583 let native_thinking_active =
1589 request.thinking.is_some() && bedrock_model_supports_native_thinking(model);
1590 let (effective_temperature, additional_fields, effective_max_tokens) = match request
1591 .thinking
1592 {
1593 Some(params) if bedrock_model_supports_native_thinking(model) => {
1594 ::zeroclaw_log::record!(
1595 INFO,
1596 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1597 .with_attrs(::serde_json::json!({"budget_tokens": params.budget_tokens})),
1598 "Bedrock native extended thinking enabled; forcing temperature=1.0"
1599 );
1600 let fields = serde_json::json!({
1601 "thinking": {
1602 "type": "enabled",
1603 "budget_tokens": params.budget_tokens
1604 }
1605 });
1606 let min_required = params.budget_tokens + 1;
1608 let max_tokens = self.max_tokens.max(min_required);
1609 (1.0, Some(fields), max_tokens)
1610 }
1611 Some(_) => {
1612 ::zeroclaw_log::record!(
1613 WARN,
1614 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1615 .with_attrs(::serde_json::json!({"model": model})),
1616 "Native extended thinking requested but model only supports adaptive thinking; falling back to prompt-based reasoning"
1617 );
1618 (temperature, None, self.max_tokens)
1619 }
1620 None => (temperature, None, self.max_tokens),
1621 };
1622
1623 let serialized_temperature = if native_thinking_active {
1627 Some(effective_temperature)
1628 } else if bedrock_model_omits_temperature(model) {
1629 None
1630 } else {
1631 Some(effective_temperature)
1632 };
1633
1634 let converse_request = ConverseRequest {
1635 system,
1636 messages: converse_messages,
1637 inference_config: Some(InferenceConfig {
1638 max_tokens: effective_max_tokens,
1639 temperature: serialized_temperature,
1640 }),
1641 tool_config,
1642 additional_model_request_fields: additional_fields,
1643 };
1644
1645 let response = self
1646 .send_converse_request(&auth, model, &converse_request)
1647 .await?;
1648
1649 Ok(Self::parse_converse_response(response))
1650 }
1651
1652 async fn warmup(&self) -> anyhow::Result<()> {
1653 let region = match self.auth {
1654 Some(BedrockAuth::SigV4(ref creds)) => creds.region.clone(),
1655 Some(BedrockAuth::BearerToken(_)) => Self::resolve_region(),
1656 None => return Ok(()),
1657 };
1658 let url = format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/");
1659 let _ = self.http_client().get(&url).send().await;
1660 Ok(())
1661 }
1662}
1663
1664impl ::zeroclaw_api::attribution::Attributable for BedrockModelProvider {
1667 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1668 ::zeroclaw_api::attribution::Role::Provider(
1669 ::zeroclaw_api::attribution::ProviderKind::Model(
1670 ::zeroclaw_api::attribution::ModelProviderKind::Bedrock,
1671 ),
1672 )
1673 }
1674 fn alias(&self) -> &str {
1675 &self.alias
1676 }
1677}
1678
1679#[cfg(test)]
1680mod tests {
1681 use super::*;
1682 use crate::test_util::{EnvGuard, env_lock};
1683 use crate::traits::ChatMessage;
1684
/// The digest of empty input must equal the well-known SHA-256 empty hash.
#[test]
fn sha256_hex_empty_string() {
    let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    let digest = sha256_hex(b"");
    assert_eq!(digest, expected);
}
1695
/// Hashing `"hello"` must produce its known SHA-256 hex digest.
#[test]
fn sha256_hex_known_input() {
    let expected = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824";
    let digest = sha256_hex(b"hello");
    assert_eq!(digest, expected);
}
1704
/// Secret access key fed to the signing-key derivation tests.
/// NOTE(review): this matches the example secret used in AWS's published
/// SigV4 signing examples and is presumably not a real credential — confirm.
const TEST_VECTOR_SECRET: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
1707
/// HMAC-SHA256 of `"message"` under key `"key"` must match the known digest.
#[test]
fn hmac_sha256_known_input() {
    let expected = "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a";
    let mac = hmac_sha256(&b"key"[..], b"message");
    assert_eq!(hex::encode(&mac), expected);
}
1717
/// A derived signing key is always 32 bytes long.
#[test]
fn derive_signing_key_structure() {
    let (date, region, service) = ("20150830", "us-east-1", "iam");
    let signing_key = derive_signing_key(TEST_VECTOR_SECRET, date, region, service);
    assert_eq!(signing_key.len(), 32);
}
1724
/// Deriving a key for 2015-08-30 / us-east-1 / iam must reproduce the
/// expected hex-encoded signing key.
#[test]
fn derive_signing_key_known_test_vector() {
    let expected = "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9";
    let signing_key = derive_signing_key(TEST_VECTOR_SECRET, "20150830", "us-east-1", "iam");
    assert_eq!(hex::encode(&signing_key), expected);
}
1734
/// The authorization header must name the algorithm and access key, list the
/// signed headers in order, include a signature, and scope the credential to
/// the `us-east-1` region and the `bedrock` signing service.
#[test]
fn build_authorization_header_format() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        // Reuse the shared test-vector secret instead of repeating the literal.
        secret_access_key: TEST_VECTOR_SECRET.to_string(),
        session_token: None,
        region: "us-east-1".to_string(),
        expires_at: None,
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
    ];

    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/anthropic.claude-3-sonnet/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/"));
    assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
    assert!(auth.contains("Signature="));
    assert!(auth.contains("/us-east-1/bedrock/aws4_request"));
}
1774
/// When a session token header is among the headers to sign, its name must
/// appear in the resulting authorization header.
#[test]
fn build_authorization_header_includes_security_token_in_signed_headers() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        // Reuse the shared test-vector secret instead of repeating the literal.
        secret_access_key: TEST_VECTOR_SECRET.to_string(),
        session_token: Some("session-token-value".to_string()),
        region: "us-east-1".to_string(),
        expires_at: None,
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
        (
            "x-amz-security-token".to_string(),
            "session-token-value".to_string(),
        ),
    ];

    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/test-model/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    assert!(auth.contains("x-amz-security-token"));
}
1814
/// The host name is built from the runtime prefix and the credential region.
#[test]
fn credentials_host_formats_correctly() {
    let region = "us-west-2".to_string();
    let credentials = AwsCredentials {
        access_key_id: "AKID".to_string(),
        secret_access_key: "secret".to_string(),
        session_token: None,
        region,
        expires_at: None,
    };
    let host = credentials.host();
    assert_eq!(host, "bedrock-runtime.us-west-2.amazonaws.com");
}
1828
/// Constructing a provider must not panic, whatever credentials are present.
#[test]
fn creates_without_credentials() {
    drop(BedrockModelProvider::new("test"));
}
1836
1837 #[tokio::test]
1838 #[allow(clippy::await_holding_lock)]
1839 async fn chat_fails_without_credentials() {
1840 let _env_lock = env_lock();
1841 let _ak = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
1842 let _sk = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);
1843 let _bearer = EnvGuard::set("BEDROCK_API_KEY", None);
1844 let _config = EnvGuard::set("AWS_CONFIG_FILE", Some("/dev/null"));
1845 let model_provider = BedrockModelProvider {
1846 alias: "test".to_string(),
1847 auth: None,
1848 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
1849 cred_cache: Mutex::new(None),
1850 };
1851 let result = model_provider
1852 .chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", Some(0.7))
1853 .await;
1854 assert!(result.is_err());
1855 let err = result.unwrap_err().to_string();
1856 assert!(
1857 err.contains("credentials not set")
1858 || err.contains("169.254.169.254")
1859 || err.to_lowercase().contains("credential")
1860 || err.to_lowercase().contains("builder error"),
1861 "Expected missing-credentials style error, got: {err}"
1862 );
1863 }
1864
/// The bearer-token constructor stores the token as the provider's auth.
#[test]
fn creates_with_bearer_token() {
    let provider = BedrockModelProvider::with_bearer_token("test", "test-api-key");
    let auth = provider.auth.as_ref();
    assert!(auth.is_some());
    assert!(matches!(
        auth,
        Some(BedrockAuth::BearerToken(token)) if token == "test-api-key"
    ));
}
1875
/// `BEDROCK_API_KEY` alone is enough for `new` to pick bearer-token auth.
#[test]
fn bearer_token_from_env() {
    let _env_lock = env_lock();
    let _guard = EnvGuard::set("BEDROCK_API_KEY", Some("env-bearer-token"));
    let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
    let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);

    let provider = BedrockModelProvider::new("test");
    let uses_env_token = matches!(
        provider.auth,
        Some(BedrockAuth::BearerToken(ref token)) if token == "env-bearer-token"
    );
    assert!(uses_env_token);
}
1890
/// When both a bearer token and static AWS keys are set, the bearer token
/// takes precedence.
#[test]
fn bearer_token_precedence() {
    let _env_lock = env_lock();
    let _bearer_guard = EnvGuard::set("BEDROCK_API_KEY", Some("bearer-key"));
    let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("AKIAEXAMPLE"));
    let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret"));

    let provider = BedrockModelProvider::new("test");
    let prefers_bearer = matches!(
        provider.auth,
        Some(BedrockAuth::BearerToken(ref token)) if token == "bearer-key"
    );
    assert!(prefers_bearer);
}
1905
/// The endpoint URL combines region and model id into the Converse path.
#[test]
fn endpoint_url_formats_correctly() {
    let expected =
        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse";
    let actual = BedrockModelProvider::endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6");
    assert_eq!(actual, expected);
}
1916
/// The request URL keeps the colon of a versioned model id unescaped.
#[test]
fn endpoint_url_keeps_raw_colon() {
    let model = "anthropic.claude-3-5-haiku-20241022-v1:0";
    let url = BedrockModelProvider::endpoint_url("us-west-2", model);
    assert!(url.contains("/model/anthropic.claude-3-5-haiku-20241022-v1:0/converse"));
}
1926
/// The canonical URI percent-encodes the colon of a versioned model id.
#[test]
fn canonical_uri_encodes_colon() {
    let model = "anthropic.claude-3-5-haiku-20241022-v1:0";
    let canonical = BedrockModelProvider::canonical_uri(model);
    assert_eq!(
        canonical,
        "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse"
    );
}
1936
/// A model id without a colon passes through the canonical URI unchanged.
#[test]
fn canonical_uri_no_colon_unchanged() {
    let canonical = BedrockModelProvider::canonical_uri("anthropic.claude-sonnet-4-6");
    let expected = "/model/anthropic.claude-sonnet-4-6/converse";
    assert_eq!(canonical, expected);
}
1942
/// A system message is lifted into the system blocks and removed from the
/// conversation messages.
#[test]
fn convert_messages_system_extracted() {
    let history = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("Hello"),
    ];

    let (system, conversation) = BedrockModelProvider::convert_messages(&history);

    assert_eq!(system.map(|blocks| blocks.len()), Some(1));
    assert_eq!(conversation.len(), 1);
    assert_eq!(conversation[0].role, "user");
}
1958
/// User and assistant turns keep their roles and order; with no system
/// message there are no system blocks.
#[test]
fn convert_messages_user_and_assistant() {
    let history = vec![
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi there"),
    ];

    let (system, conversation) = BedrockModelProvider::convert_messages(&history);

    assert!(system.is_none());
    assert_eq!(conversation.len(), 2);
    let roles: Vec<&str> = conversation.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, ["user", "assistant"]);
}
1971
/// A `tool` role message becomes a `user` message holding a tool-result block.
#[test]
fn convert_messages_tool_role_to_tool_result() {
    let payload = r#"{"tool_call_id": "call_123", "content": "Result data"}"#;
    let history = vec![ChatMessage::tool(payload)];

    let (_, conversation) = BedrockModelProvider::convert_messages(&history);

    assert_eq!(conversation.len(), 1);
    let converted = &conversation[0];
    assert_eq!(converted.role, "user");
    assert!(matches!(converted.content[0], ContentBlock::ToolResult(_)));
}
1981
/// An assistant message carrying serialized tool calls expands into a text
/// block followed by a tool-use block.
#[test]
fn convert_messages_assistant_tool_calls_parsed() {
    let payload = r#"{"content": "Let me check", "tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]}"#;
    let history = vec![ChatMessage::assistant(payload)];

    let (_, conversation) = BedrockModelProvider::convert_messages(&history);

    assert_eq!(conversation.len(), 1);
    let assistant = &conversation[0];
    assert_eq!(assistant.role, "assistant");
    assert_eq!(assistant.content.len(), 2);
    assert!(matches!(assistant.content[0], ContentBlock::Text(_)));
    assert!(matches!(assistant.content[1], ContentBlock::ToolUse(_)));
}
1993
/// Plain assistant text that is not a tool-call payload stays a text block.
#[test]
fn convert_messages_plain_assistant_text() {
    let history = vec![ChatMessage::assistant("Just text")];
    let (_, conversation) = BedrockModelProvider::convert_messages(&history);
    assert_eq!(conversation.len(), 1);
    let first_block = &conversation[0].content[0];
    assert!(matches!(first_block, ContentBlock::Text(_)));
}
2001
/// A short system prompt is not worth a cache point.
#[test]
fn should_cache_system_small_prompt() {
    let cacheable = BedrockModelProvider::should_cache_system("Short prompt");
    assert!(!cacheable);
}
2008
/// A 3073-character system prompt qualifies for caching.
#[test]
fn should_cache_system_large_prompt() {
    let prompt = "a".repeat(3073);
    let cacheable = BedrockModelProvider::should_cache_system(&prompt);
    assert!(cacheable);
}
2014
/// The caching threshold is exclusive: 3072 characters do not qualify,
/// 3073 do.
#[test]
fn should_cache_system_boundary() {
    let at_limit = "a".repeat(3072);
    let over_limit = "a".repeat(3073);
    assert!(!BedrockModelProvider::should_cache_system(&at_limit));
    assert!(BedrockModelProvider::should_cache_system(&over_limit));
}
2022
/// A conversation with only two non-system turns is not cached.
#[test]
fn should_cache_conversation_short() {
    let history = vec![
        ChatMessage::system("System"),
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi"),
    ];
    let cacheable = BedrockModelProvider::should_cache_conversation(&history);
    assert!(!cacheable);
}
2032
/// A system message followed by five alternating user/assistant turns is
/// long enough to be cached.
#[test]
fn should_cache_conversation_long() {
    let mut history = vec![ChatMessage::system("System")];
    history.extend((0..5).map(|turn| {
        let role = if turn % 2 == 0 { "user" } else { "assistant" };
        ChatMessage {
            role: role.to_string(),
            content: format!("Message {turn}"),
        }
    }));
    assert!(BedrockModelProvider::should_cache_conversation(&history));
}
2044
/// A single tool spec becomes a tool config with one identically named entry.
#[test]
fn convert_tools_to_converse_formats_correctly() {
    let schema = serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}});
    let specs = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: schema,
    }];

    let converted = BedrockModelProvider::convert_tools_to_converse(Some(&specs));

    assert!(converted.is_some());
    let tool_config = converted.unwrap();
    assert_eq!(tool_config.tools.len(), 1);
    assert_eq!(tool_config.tools[0].tool_spec.name, "shell");
}
2060
/// Neither an empty tool list nor a missing one produces a tool config.
#[test]
fn convert_tools_to_converse_empty_returns_none() {
    let no_tools: &[ToolSpec] = &[];
    let from_empty = BedrockModelProvider::convert_tools_to_converse(Some(no_tools));
    let from_none = BedrockModelProvider::convert_tools_to_converse(None);
    assert!(from_empty.is_none());
    assert!(from_none.is_none());
}
2066
/// With no system blocks, the serialized request omits `system` entirely
/// while keeping the message text and the camelCase `maxTokens` key.
#[test]
fn converse_request_serializes_without_system() {
    let user_message = ConverseMessage {
        role: "user".to_string(),
        content: vec![ContentBlock::Text(TextBlock {
            text: "Hello".to_string(),
        })],
    };
    let inference_config = InferenceConfig {
        max_tokens: 4096,
        temperature: Some(0.7),
    };
    let request = ConverseRequest {
        system: None,
        messages: vec![user_message],
        inference_config: Some(inference_config),
        tool_config: None,
        additional_model_request_fields: None,
    };

    let serialized = serde_json::to_string(&request).unwrap();

    assert!(!serialized.contains("system"));
    assert!(serialized.contains("Hello"));
    assert!(serialized.contains("maxTokens"));
}
2091
/// Opus 4.7 model ids, with or without region prefix and version suffix,
/// are flagged as omitting the temperature.
#[test]
fn bedrock_model_omits_temperature_matches_opus_4_7() {
    let opus_4_7_ids = [
        "us.anthropic.claude-opus-4-7",
        "anthropic.claude-opus-4-7-v1:0",
    ];
    for model in opus_4_7_ids {
        assert!(bedrock_model_omits_temperature(model));
    }
}
2103
/// Models other than Opus 4.7 still receive a temperature.
#[test]
fn bedrock_model_omits_temperature_skips_other_models() {
    let other_ids = [
        "us.anthropic.claude-opus-4-6-v1",
        "us.anthropic.claude-sonnet-4-6-v1",
        "us.anthropic.claude-haiku-4-5-v1",
    ];
    for model in other_ids {
        assert!(!bedrock_model_omits_temperature(model));
    }
}
2116
/// Opus 4.7 model ids are reported as not supporting native thinking.
#[test]
fn bedrock_model_supports_native_thinking_excludes_opus_4_7() {
    let opus_4_7_ids = [
        "us.anthropic.claude-opus-4-7",
        "anthropic.claude-opus-4-7-v1:0",
    ];
    for model in opus_4_7_ids {
        assert!(!bedrock_model_supports_native_thinking(model));
    }
}
2128
/// Models other than Opus 4.7 are reported as supporting native thinking.
#[test]
fn bedrock_model_supports_native_thinking_allows_other_models() {
    let other_ids = [
        "us.anthropic.claude-opus-4-6-v1",
        "us.anthropic.claude-sonnet-4-6-v1",
        "us.anthropic.claude-haiku-4-5-v1",
    ];
    for model in other_ids {
        assert!(bedrock_model_supports_native_thinking(model));
    }
}
2141
/// A `None` temperature is left out of the serialized inference config.
#[test]
fn inference_config_serializes_without_temperature_when_none() {
    let config = InferenceConfig {
        max_tokens: 4096,
        temperature: None,
    };
    let json = serde_json::to_string(&config).unwrap();
    assert!(json.contains("maxTokens"));
    let has_temperature = json.contains("temperature");
    assert!(
        !has_temperature,
        "expected temperature to be omitted, got: {json}"
    );
}
2155
/// A `Some` temperature appears in the serialized inference config.
#[test]
fn inference_config_serializes_with_temperature_when_some() {
    let config = InferenceConfig {
        max_tokens: 4096,
        temperature: Some(0.7),
    };
    let json = serde_json::to_string(&config).unwrap();
    assert!(json.contains("maxTokens"));
    let has_temperature = json.contains("temperature");
    assert!(
        has_temperature,
        "expected temperature to be present, got: {json}"
    );
}
2169
/// A response with one text block yields that text and no tool calls.
#[test]
fn converse_response_deserializes_text() {
    let body = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [{"text": "Hello from Bedrock"}]
            }
        },
        "stopReason": "end_turn"
    }"#;

    let response: ConverseResponse = serde_json::from_str(body).unwrap();
    let chat = BedrockModelProvider::parse_converse_response(response);

    assert_eq!(chat.text.as_deref(), Some("Hello from Bedrock"));
    assert!(chat.tool_calls.is_empty());
}
2186
/// A response with one tool-use block yields a single tool call carrying
/// its id and name, and no text.
#[test]
fn converse_response_deserializes_tool_use() {
    let body = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {"toolUse": {"toolUseId": "call_1", "name": "shell", "input": {"command": "ls"}}}
                ]
            }
        },
        "stopReason": "tool_use"
    }"#;

    let response: ConverseResponse = serde_json::from_str(body).unwrap();
    let chat = BedrockModelProvider::parse_converse_response(response);

    assert!(chat.text.is_none());
    assert_eq!(chat.tool_calls.len(), 1);
    let call = &chat.tool_calls[0];
    assert_eq!(call.name, "shell");
    assert_eq!(call.id, "call_1");
}
2207
/// A null output parses into a response with no text and no tool calls.
#[test]
fn converse_response_empty_output() {
    let body = r#"{"output": null, "stopReason": null}"#;
    let response: ConverseResponse = serde_json::from_str(body).unwrap();
    let chat = BedrockModelProvider::parse_converse_response(response);
    assert!(chat.text.is_none());
    assert!(chat.tool_calls.is_empty());
}
2216
/// A text block serializes to a flat `{"text": ...}` object.
#[test]
fn content_block_text_serializes_as_flat_string() {
    let text_block = TextBlock {
        text: "Hello".to_string(),
    };
    let serialized = serde_json::to_string(&ContentBlock::Text(text_block)).unwrap();
    assert_eq!(serialized, r#"{"text":"Hello"}"#);
}
2226
/// A tool-use block serializes under a `toolUse` key with camelCase fields.
#[test]
fn content_block_tool_use_serializes_with_nested_object() {
    let tool_use = ToolUseBlock {
        tool_use_id: "call_1".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "ls"}),
    };
    let block = ContentBlock::ToolUse(ToolUseWrapper { tool_use });

    let serialized = serde_json::to_string(&block).unwrap();

    assert!(serialized.contains(r#""toolUse""#));
    assert!(serialized.contains(r#""toolUseId":"call_1""#));
}
2240
/// A cache-point content block serializes under a `cachePoint` key.
#[test]
fn content_block_cache_point_serializes() {
    let cache_point = CachePoint::default_cache();
    let block = ContentBlock::CachePointBlock(CachePointWrapper { cache_point });
    let serialized = serde_json::to_string(&block).unwrap();
    assert_eq!(serialized, r#"{"cachePoint":{"type":"default"}}"#);
}
2249
/// A text block survives a serialize/deserialize round trip.
#[test]
fn content_block_text_round_trips() {
    let source = ContentBlock::Text(TextBlock {
        text: "Hello".to_string(),
    });
    let encoded = serde_json::to_string(&source).unwrap();
    let decoded: ContentBlock = serde_json::from_str(&encoded).unwrap();
    let preserved = matches!(decoded, ContentBlock::Text(block) if block.text == "Hello");
    assert!(preserved);
}
2259
/// The default cache point serializes to `{"type":"default"}`.
#[test]
fn cache_point_serializes() {
    let serialized = serde_json::to_string(&CachePoint::default_cache()).unwrap();
    assert_eq!(serialized, r#"{"type":"default"}"#);
}
2266
2267 #[tokio::test]
2268 async fn warmup_without_credentials_is_noop() {
2269 let model_provider = BedrockModelProvider {
2270 alias: "test".to_string(),
2271 auth: None,
2272 max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
2273 cred_cache: Mutex::new(None),
2274 };
2275 let result = model_provider.warmup().await;
2276 assert!(result.is_ok());
2277 }
2278
/// The advertised capabilities include native tool calling.
#[test]
fn capabilities_reports_native_tool_calling() {
    let provider = BedrockModelProvider {
        alias: "test".to_string(),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: Mutex::new(None),
    };
    let ProviderCapabilities {
        native_tool_calling,
        ..
    } = provider.capabilities();
    assert!(native_tool_calling);
}
2290
/// Token counts in the `usage` object are deserialized into the response.
#[test]
fn converse_response_parses_usage() {
    // The text block uses the flat `{"text": "..."}` shape, matching the
    // other response fixtures in this module.
    let json = r#"{
        "output": {"message": {"role": "assistant", "content": [{"text": "Hello"}]}},
        "usage": {"inputTokens": 500, "outputTokens": 100}
    }"#;
    let resp: ConverseResponse = serde_json::from_str(json).unwrap();
    let usage = resp.usage.unwrap();
    assert_eq!(usage.input_tokens, Some(500));
    assert_eq!(usage.output_tokens, Some(100));
}
2302
/// A response without a `usage` object deserializes with usage unset.
#[test]
fn converse_response_parses_without_usage() {
    let body = r#"{"output": {"message": {"role": "assistant", "content": []}}}"#;
    let response: ConverseResponse = serde_json::from_str(body).unwrap();
    let usage = response.usage;
    assert!(usage.is_none());
}
2309
/// A tool message whose content is not JSON must still be converted into a
/// tool-result block on a `user` message.
#[test]
fn fallback_tool_result_emits_tool_result_block_not_text() {
    let assistant_turn = ChatMessage::assistant(
        r#"{"content":"","tool_calls":[{"id":"tool_1","name":"shell","arguments":"{}"}]}"#,
    );
    let raw_tool_turn = ChatMessage {
        role: "tool".to_string(),
        content: "not valid json".to_string(),
    };
    let history = vec![
        ChatMessage::user("do something"),
        assistant_turn,
        raw_tool_turn,
    ];

    let (_, conversation) = BedrockModelProvider::convert_messages(&history);

    let tool_msg = &conversation[2];
    assert_eq!(tool_msg.role, "user");
    let is_tool_result = matches!(&tool_msg.content[0], ContentBlock::ToolResult(_));
    assert!(
        is_tool_result,
        "Expected ToolResult block, got {:?}",
        tool_msg.content[0]
    );
}
2335
/// When a tool message is not JSON, the tool-use id is taken from the
/// preceding assistant tool call and the result is marked as an error.
#[test]
fn fallback_recovers_tool_use_id_from_assistant() {
    let assistant_turn = ChatMessage::assistant(
        r#"{"content":"","tool_calls":[{"id":"tool_abc","name":"shell","arguments":"{}"}]}"#,
    );
    let raw_tool_turn = ChatMessage {
        role: "tool".to_string(),
        content: "raw output with no json".to_string(),
    };
    let history = vec![ChatMessage::user("run it"), assistant_turn, raw_tool_turn];

    let (_, conversation) = BedrockModelProvider::convert_messages(&history);

    let ContentBlock::ToolResult(wrapper) = &conversation[2].content[0] else {
        panic!("Expected ToolResult block");
    };
    assert_eq!(wrapper.tool_result.tool_use_id, "tool_abc");
    assert_eq!(wrapper.tool_result.status, "error");
}
2356
/// Two back-to-back tool results must collapse into one `user` message that
/// holds two `ToolResult` blocks, giving three converted messages in total.
#[test]
fn consecutive_tool_results_merged_into_single_message() {
    let messages = vec![
        ChatMessage::user("do two things"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"t1","name":"a","arguments":"{}"},{"id":"t2","name":"b","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"t1","content":"result 1"}"#),
        ChatMessage::tool(r#"{"tool_call_id":"t2","content":"result 2"}"#),
    ];

    let (_, converted) = BedrockModelProvider::convert_messages(&messages);
    assert_eq!(
        converted.len(),
        3,
        "Expected 3 messages, got {}",
        converted.len()
    );

    let merged = &converted[2];
    assert_eq!(merged.role, "user");
    assert_eq!(
        merged.content.len(),
        2,
        "Expected 2 tool results in one message"
    );
    // The length was pinned to 2 above, so this visits exactly blocks 0 and 1.
    for block in &merged.content {
        assert!(matches!(block, ContentBlock::ToolResult(_)));
    }
}
2379
/// `extract_tool_call_id` must accept the id under `tool_call_id`,
/// `tool_use_id`, or `toolUseId`, and return `None` for non-JSON input.
#[test]
fn extract_tool_call_id_tries_multiple_field_names() {
    // (raw tool message content, expected extracted id)
    let cases = [
        (r#"{"tool_call_id":"a"}"#, Some("a".to_string())),
        (r#"{"tool_use_id":"b"}"#, Some("b".to_string())),
        (r#"{"toolUseId":"c"}"#, Some("c".to_string())),
        ("not json at all", None),
    ];
    for (raw, expected) in cases {
        assert_eq!(BedrockModelProvider::extract_tool_call_id(raw), expected);
    }
}
2399
/// `parse_tool_result_message` must accept a payload that names the id field
/// `tool_use_id` and carry that id through to the resulting `ToolResult` block.
#[test]
fn parse_tool_result_accepts_alternate_id_fields() {
    // `expect` replaces the former `is_some()` + `unwrap()` pair so a failure
    // states which input was rejected.
    let msg = BedrockModelProvider::parse_tool_result_message(
        r#"{"tool_use_id":"x","content":"ok"}"#,
    )
    .expect("payload keyed by `tool_use_id` should parse into a message");

    match &msg.content[0] {
        ContentBlock::ToolResult(wrapper) => {
            assert_eq!(wrapper.tool_result.tool_use_id, "x");
        }
        _ => panic!("Expected ToolResult"),
    }
}
2412
/// An empty text block is expected to be rewritten to the `(empty)` placeholder
/// while the message itself is kept.
#[test]
fn sanitize_removes_empty_text_blocks() {
    let blank = ContentBlock::Text(TextBlock {
        text: String::new(),
    });
    let mut messages = vec![ConverseMessage {
        role: String::from("assistant"),
        content: vec![blank],
    }];

    BedrockModelProvider::sanitize_empty_content_blocks(&mut messages);

    assert_eq!(messages.len(), 1);
    match &messages[0].content[0] {
        ContentBlock::Text(tb) => assert_eq!(tb.text, "(empty)"),
        _ => panic!("Expected Text block with placeholder"),
    }
}
2429
/// A text block that already has content must come through sanitization
/// unchanged.
#[test]
fn sanitize_preserves_non_empty_text_blocks() {
    let greeting = ContentBlock::Text(TextBlock {
        text: String::from("Hello"),
    });
    let mut messages = vec![ConverseMessage {
        role: String::from("user"),
        content: vec![greeting],
    }];

    BedrockModelProvider::sanitize_empty_content_blocks(&mut messages);

    match &messages[0].content[0] {
        ContentBlock::Text(tb) => assert_eq!(tb.text, "Hello"),
        _ => panic!("Expected preserved Text block"),
    }
}
2445
/// An assistant turn with empty content must be converted into a text block
/// whose text is non-empty.
#[test]
fn convert_messages_empty_assistant_gets_placeholder() {
    let blank_assistant = ChatMessage {
        role: String::from("assistant"),
        content: String::new(),
    };
    let messages = vec![
        ChatMessage::user("Hello"),
        blank_assistant,
        ChatMessage::user("Continue"),
    ];

    let (_, converse) = BedrockModelProvider::convert_messages(&messages);
    let assistant_msg = &converse[1];
    assert_eq!(assistant_msg.role, "assistant");

    match &assistant_msg.content[0] {
        ContentBlock::Text(tb) => {
            assert!(!tb.text.is_empty(), "Assistant text should not be empty");
        }
        _ => panic!("Expected Text block for assistant message"),
    }
}
2465
/// The `[default]` section must yield both its `credential_process` command
/// (kept verbatim, including embedded `=` signs) and its `region`.
#[test]
fn parse_aws_config_default_profile() {
    let config = "\
[default]
region=us-west-2
credential_process=ada credentials print --account=123 --provider=conduit --role=MyRole
";
    // `expect` replaces the former `is_some()` + `unwrap()` pair and names the
    // scenario if parsing fails.
    let (cmd, region) = AwsCredentials::parse_aws_config(config, "default")
        .expect("[default] section with credential_process should parse");
    assert_eq!(
        cmd,
        "ada credentials print --account=123 --provider=conduit --role=MyRole"
    );
    assert_eq!(region.as_deref(), Some("us-west-2"));
}
2484
/// A named profile is looked up under `[profile <name>]`; values from the
/// `[default]` section must not leak into the result.
#[test]
fn parse_aws_config_named_profile() {
    let config = "\
[default]
region=us-east-1

[profile myprofile]
region=eu-west-1
credential_process=aws sso get-role-credentials --profile myprofile
";
    // `expect` replaces the former `is_some()` + `unwrap()` pair and names the
    // scenario if parsing fails.
    let (cmd, region) = AwsCredentials::parse_aws_config(config, "myprofile")
        .expect("[profile myprofile] section with credential_process should parse");
    assert!(cmd.contains("myprofile"));
    assert_eq!(region.as_deref(), Some("eu-west-1"));
}
2501
/// A profile that defines only `region` and no `credential_process` must
/// produce `None`.
#[test]
fn parse_aws_config_missing_credential_process() {
    let config = "\
[default]
region=us-west-2
";
    let parsed = AwsCredentials::parse_aws_config(config, "default");
    assert!(parsed.is_none());
}
2511
/// Lines starting with `#` or `;` are comments; only the uncommented
/// `credential_process` entry may be returned.
#[test]
fn parse_aws_config_ignores_comments() {
    let config = "\
[default]
# credential_process=should-be-ignored
; credential_process=also-ignored
credential_process=real-command
";
    // `expect` replaces the former `is_some()` + `unwrap()` pair and names the
    // scenario if parsing fails.
    let (cmd, _region) = AwsCredentials::parse_aws_config(config, "default")
        .expect("uncommented credential_process line should be picked up");
    assert_eq!(cmd, "real-command");
}
2524
/// Asking for a profile that has no section in the config must produce `None`,
/// even when `[default]` defines a `credential_process`.
#[test]
fn parse_aws_config_nonexistent_profile() {
    let config = "\
[default]
credential_process=some-command
";
    let parsed = AwsCredentials::parse_aws_config(config, "nonexistent");
    assert!(parsed.is_none());
}
2534
/// Extracts a `credential_process` command from a config snippet, runs it via
/// `sh -c`, and checks that the JSON it prints exposes the expected
/// `AccessKeyId`, `SecretAccessKey`, and `SessionToken` values.
#[test]
fn from_credential_process_parses_json_output() {
    let config = "\
[default]
credential_process=echo '{\"Version\":1,\"AccessKeyId\":\"AKIA\",\"SecretAccessKey\":\"secret\",\"SessionToken\":\"tok\"}'
region=ap-southeast-1
";
    let (cmd, region) = AwsCredentials::parse_aws_config(config, "default")
        .expect("[default] section with credential_process should parse");
    assert!(cmd.starts_with("echo"));
    assert_eq!(region.as_deref(), Some("ap-southeast-1"));

    let output = std::process::Command::new("sh")
        .args(["-c", cmd.as_str()])
        .output()
        .expect("failed to spawn `sh` to run the credential_process command");
    // Check the exit status first so a failing command is reported with its
    // stderr instead of surfacing as an opaque JSON parse error below.
    assert!(
        output.status.success(),
        "credential_process command exited with {}: {}",
        output.status,
        String::from_utf8_lossy(&output.stderr)
    );

    let json: serde_json::Value = serde_json::from_slice(&output.stdout)
        .expect("credential_process stdout should be valid JSON");
    assert_eq!(json["AccessKeyId"].as_str(), Some("AKIA"));
    assert_eq!(json["SecretAccessKey"].as_str(), Some("secret"));
    assert_eq!(json["SessionToken"].as_str(), Some("tok"));
}
2556
/// With `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` set, `from_env` must
/// succeed and report the access key taken from the environment.
#[test]
fn env_vars_take_precedence_over_credential_process() {
    // The guards stay bound until the end of the test so the lock and the
    // variable overrides remain in effect while `from_env` runs.
    let _env_lock = env_lock();
    let _ak = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("FROM_ENV"));
    let _sk = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret_from_env"));

    // `expect` replaces the former `is_ok()` + `unwrap()` pair so the
    // underlying error is printed if loading fails.
    let creds = AwsCredentials::from_env()
        .expect("from_env should succeed when both key variables are set");
    assert_eq!(creds.access_key_id, "FROM_ENV");
}
2567
2568 fn make_creds(expires_at: Option<chrono::DateTime<chrono::Utc>>) -> AwsCredentials {
2571 AwsCredentials {
2572 access_key_id: "AKIA".to_string(),
2573 secret_access_key: "secret".to_string(),
2574 session_token: Some("tok".to_string()),
2575 region: "us-west-2".to_string(),
2576 expires_at,
2577 }
2578 }
2579
/// Credentials without an expiry timestamp must never report as expired.
#[test]
fn is_expired_returns_false_when_no_expiry() {
    let never_expiring = make_creds(None);
    let expired = never_expiring.is_expired();
    assert!(!expired);
}
2585
/// Credentials that expire an hour from now must not report as expired.
#[test]
fn is_expired_returns_false_when_future() {
    let one_hour_ahead = chrono::Utc::now() + chrono::Duration::hours(1);
    let expired = make_creds(Some(one_hour_ahead)).is_expired();
    assert!(!expired);
}
2592
/// Credentials whose expiry lies an hour in the past must report as expired.
#[test]
fn is_expired_returns_true_when_past() {
    let one_hour_ago = chrono::Utc::now() - chrono::Duration::hours(1);
    let expired = make_creds(Some(one_hour_ago)).is_expired();
    assert!(expired);
}
2599
/// Credentials expiring only 30 seconds from now are expected to count as
/// expired already.
// NOTE(review): this relies on `is_expired` applying a skew allowance larger
// than 30 seconds; the allowance itself is defined outside this test.
#[test]
fn is_expired_returns_true_within_skew_window() {
    let thirty_seconds_ahead = chrono::Utc::now() + chrono::Duration::seconds(30);
    let expired = make_creds(Some(thirty_seconds_ahead)).is_expired();
    assert!(expired);
}
2607
/// A provider whose credential cache holds nothing must return `None` from
/// `cached_credentials`.
#[test]
fn cached_credentials_returns_none_when_empty() {
    let empty_cache = Mutex::new(None);
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: empty_cache,
    };
    let lookup = provider.cached_credentials();
    assert!(lookup.is_none());
}
2618
/// A cache entry that expires an hour from now must be returned by
/// `cached_credentials` with its access key intact.
#[test]
fn cached_credentials_returns_some_when_valid() {
    let future = chrono::Utc::now() + chrono::Duration::hours(1);
    let model_provider = BedrockModelProvider {
        alias: "test".to_string(),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: Mutex::new(Some(make_creds(Some(future)))),
    };
    // `expect` replaces the former `is_some()` + `unwrap()` pair and names the
    // scenario if the lookup comes back empty.
    let cached = model_provider
        .cached_credentials()
        .expect("unexpired cached credentials should be returned");
    assert_eq!(cached.access_key_id, "AKIA");
}
2632
/// A cache entry whose expiry lies an hour in the past must not be returned by
/// `cached_credentials`.
#[test]
fn cached_credentials_returns_none_when_expired() {
    let one_hour_ago = chrono::Utc::now() - chrono::Duration::hours(1);
    let stale_cache = Mutex::new(Some(make_creds(Some(one_hour_ago))));
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: stale_cache,
    };
    let lookup = provider.cached_credentials();
    assert!(lookup.is_none());
}
2644
/// Round trip: the cache starts empty, and after `cache_credentials` stores an
/// unexpired entry, `cached_credentials` returns something.
#[test]
fn cache_credentials_stores_and_retrieves() {
    let one_hour_ahead = chrono::Utc::now() + chrono::Duration::hours(1);
    let provider = BedrockModelProvider {
        alias: String::from("test"),
        auth: None,
        max_tokens: zeroclaw_api::model_provider::BASELINE_MAX_TOKENS,
        cred_cache: Mutex::new(None),
    };

    let before = provider.cached_credentials();
    assert!(before.is_none());

    let fresh = make_creds(Some(one_hour_ahead));
    provider.cache_credentials(&fresh);

    let after = provider.cached_credentials();
    assert!(after.is_some());
}
2658}