1use crate::traits::{
15 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
16 ModelProvider, TokenUsage, ToolCall as ProviderToolCall,
17};
18use async_trait::async_trait;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::Mutex;
25use zeroclaw_api::tool::ToolSpec;
26
/// OAuth client ID sent in the GitHub device-code flow requests.
/// NOTE(review): presumably the public client ID of the official Copilot editor
/// integration — confirm before changing, the token exchange may depend on it.
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// Starts the device-code flow (returns device code, user code, verification URL).
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Polled with the device code until GitHub hands out an OAuth access token.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Exchanges a GitHub access token for an expiring Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// Copilot chat API base URL, used when the key exchange advertises no endpoint.
const DEFAULT_API: &str = "https://api.githubcopilot.com";
33
/// Response of `GITHUB_DEVICE_CODE_URL` when the device-code login starts.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back on every access-token poll.
    device_code: String,
    /// Short code the user enters at `verification_uri`.
    user_code: String,
    /// URL the user opens to authorize this device.
    verification_uri: String,
    /// Polling interval in seconds; 5 when the field is missing.
    #[serde(default = "default_interval")]
    interval: u64,
    /// Lifetime of `device_code` in seconds; 900 when the field is missing.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
46
/// Serde fallback for `DeviceCodeResponse::interval`: poll every 5 seconds
/// when GitHub does not specify an interval.
fn default_interval() -> u64 {
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}

/// Serde fallback for `DeviceCodeResponse::expires_in`: assume the device code
/// stays valid for 15 minutes (900 seconds) when GitHub omits the field.
fn default_expires_in() -> u64 {
    const FALLBACK_LIFETIME_SECS: u64 = 15 * 60;
    FALLBACK_LIFETIME_SECS
}
54
/// Reply from the device-flow access-token endpoint. It carries either
/// `access_token` (authorization finished) or `error` (for example
/// `authorization_pending`, `slow_down` or `expired_token`).
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
}

/// Copilot API key returned by `GITHUB_API_KEY_URL`. Also serialized to
/// `api-key.json` in the token directory as an on-disk cache.
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    /// Bearer token for the Copilot chat API.
    token: String,
    /// Expiry as a Unix timestamp in seconds (compared against
    /// `chrono::Utc::now().timestamp()`).
    expires_at: i64,
    /// Endpoints advertised together with the key, if any.
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}

/// Endpoint map from the key-exchange response.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    /// Chat API base URL; `DEFAULT_API` is used when this is missing.
    api: Option<String>,
}

/// In-memory copy of the current Copilot API key, with its endpoint resolved.
struct CachedApiKey {
    token: String,
    api_endpoint: String,
    expires_at: i64,
}
79
/// JSON body of `POST {endpoint}/chat/completions` (OpenAI-style shape).
#[derive(Debug, Serialize)]
struct ApiChatRequest<'a> {
    model: String,
    messages: Vec<ApiMessage>,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions, borrowing from the caller's `ToolSpec`s.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec<'a>>>,
    /// The request builder sets this to `"auto"` whenever `tools` is present.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// One chat message in the request body.
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    /// Optional because assistant messages that replay tool calls may have no text.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<ApiContent>,
    /// Only set on `tool` messages: the call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Only set on assistant messages that replay earlier tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}

/// One entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct NativeToolSpec<'a> {
    /// Serialized as `type`; always `"function"` in this module.
    #[serde(rename = "type")]
    kind: &'static str,
    function: NativeToolFunctionSpec<'a>,
}

/// Function description inside a tool entry.
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec<'a> {
    name: &'a str,
    description: &'a str,
    /// Parameter schema, passed through unchanged from `ToolSpec::parameters`.
    parameters: &'a serde_json::Value,
}

/// Tool call in either direction: serialized when replaying assistant history,
/// deserialized from the response. `id` and `kind` are optional because a
/// response may omit them; a missing `id` is replaced by a UUID when mapped.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

/// Name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    /// Arguments kept as a raw string (presumably JSON-encoded text — the code
    /// never parses it).
    arguments: String,
}
133
/// Message content. `untagged` serializes it as either a bare JSON string
/// (`Text`) or an array of typed parts (`Parts`).
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
enum ApiContent {
    Text(String),
    Parts(Vec<ContentPart>),
}

/// One content part, serialized as `{"type": "text", "text": ...}` or
/// `{"type": "image_url", "image_url": {"url": ...}}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlDetail },
}

/// Image reference taken from an `[IMAGE:...]` marker (for example a
/// `data:` URL).
#[derive(Debug, Clone, Serialize)]
struct ImageUrlDetail {
    url: String,
}
155
/// Chat-completions response body. Only the first choice is used.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
    /// Token accounting; may be absent.
    #[serde(default)]
    usage: Option<UsageInfo>,
}

/// Token counts reported by the API. Each count may be missing on its own.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}

/// One completion alternative.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// Assistant reply: text, tool calls, or both. Either may be missing.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
183
/// GitHub Copilot chat-completions provider.
///
/// It authenticates with a GitHub OAuth token, which comes from configuration,
/// a cached file, or the interactive device-code flow. It exchanges that token
/// for an expiring Copilot API key and sends OpenAI-style chat requests to the
/// endpoint that key advertises.
pub struct CopilotModelProvider {
    /// Name reported through `Attributable::alias`.
    alias: String,
    /// Explicitly configured GitHub token; `None` when absent or empty.
    github_token: Option<String>,
    /// In-memory Copilot API key cache. `get_api_key` holds this async mutex
    /// for the whole lookup/refresh, so concurrent refreshes are serialized.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    /// Directory holding the cached `access-token` and `api-key.json` files.
    token_dir: PathBuf,
}
200
201impl CopilotModelProvider {
202 pub fn new(alias: &str, github_token: Option<&str>) -> Self {
203 let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
204 .map(|dir| dir.config_dir().join("copilot"))
205 .unwrap_or_else(|| {
206 let user = std::env::var("USER")
209 .or_else(|_| std::env::var("USERNAME"))
210 .unwrap_or_else(|_| "unknown".to_string());
211 std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
212 });
213
214 if let Err(err) = std::fs::create_dir_all(&token_dir) {
215 ::zeroclaw_log::record!(
216 WARN,
217 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
218 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
219 &format!(
220 "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
221 token_dir
222 )
223 );
224 } else {
225 #[cfg(unix)]
226 {
227 use std::os::unix::fs::PermissionsExt;
228
229 if let Err(err) =
230 std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
231 {
232 ::zeroclaw_log::record!(
233 WARN,
234 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
235 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
236 &format!(
237 "Failed to set Copilot token directory permissions on {:?}: {err}",
238 token_dir
239 )
240 );
241 }
242 }
243 }
244
245 Self {
246 alias: alias.to_string(),
247 github_token: github_token
248 .filter(|token| !token.is_empty())
249 .map(String::from),
250 refresh_lock: Arc::new(Mutex::new(None)),
251 token_dir,
252 }
253 }
254 fn http_client(&self) -> Client {
255 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
256 "model_provider.copilot",
257 120,
258 10,
259 )
260 }
261
262 const COPILOT_HEADERS: [(&str, &str); 4] = [
264 ("Editor-Version", "vscode/1.85.1"),
265 ("Editor-Plugin-Version", "copilot/1.155.0"),
266 ("User-Agent", "GithubCopilot/1.155.0"),
267 ("Accept", "application/json"),
268 ];
269
270 fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
271 tools.map(|items| {
272 items
273 .iter()
274 .map(|tool| NativeToolSpec {
275 kind: "function",
276 function: NativeToolFunctionSpec {
277 name: &tool.name,
278 description: &tool.description,
279 parameters: &tool.parameters,
280 },
281 })
282 .collect()
283 })
284 }
285
286 fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
289 if role != "user" {
290 return Some(ApiContent::Text(content.to_string()));
291 }
292
293 let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
294 if image_refs.is_empty() {
295 return Some(ApiContent::Text(content.to_string()));
296 }
297
298 let mut parts = Vec::with_capacity(image_refs.len() + 1);
299 let trimmed = cleaned_text.trim();
300 if !trimmed.is_empty() {
301 parts.push(ContentPart::Text {
302 text: trimmed.to_string(),
303 });
304 }
305 for image_ref in image_refs {
306 parts.push(ContentPart::ImageUrl {
307 image_url: ImageUrlDetail { url: image_ref },
308 });
309 }
310
311 Some(ApiContent::Parts(parts))
312 }
313
314 fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
315 messages
316 .iter()
317 .map(|message| {
318 if message.role == "assistant"
319 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
320 && let Some(tool_calls_value) = value.get("tool_calls")
321 && let Ok(parsed_calls) =
322 serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
323 {
324 let tool_calls = parsed_calls
325 .into_iter()
326 .map(|tool_call| NativeToolCall {
327 id: Some(tool_call.id),
328 kind: Some("function".to_string()),
329 function: NativeFunctionCall {
330 name: tool_call.name,
331 arguments: tool_call.arguments,
332 },
333 })
334 .collect::<Vec<_>>();
335
336 let content = value
337 .get("content")
338 .and_then(serde_json::Value::as_str)
339 .map(|s| ApiContent::Text(s.to_string()));
340
341 return ApiMessage {
342 role: "assistant".to_string(),
343 content,
344 tool_call_id: None,
345 tool_calls: Some(tool_calls),
346 };
347 }
348
349 if message.role == "tool"
350 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
351 {
352 let tool_call_id = value
353 .get("tool_call_id")
354 .and_then(serde_json::Value::as_str)
355 .map(ToString::to_string);
356 let content = value
357 .get("content")
358 .and_then(serde_json::Value::as_str)
359 .map(|s| ApiContent::Text(s.to_string()));
360
361 return ApiMessage {
362 role: "tool".to_string(),
363 content,
364 tool_call_id,
365 tool_calls: None,
366 };
367 }
368
369 ApiMessage {
370 role: message.role.clone(),
371 content: Self::to_api_content(&message.role, &message.content),
372 tool_call_id: None,
373 tool_calls: None,
374 }
375 })
376 .collect()
377 }
378
379 async fn send_chat_request(
381 &self,
382 messages: Vec<ApiMessage>,
383 tools: Option<&[ToolSpec]>,
384 model: &str,
385 temperature: Option<f64>,
386 ) -> anyhow::Result<ProviderChatResponse> {
387 let (token, endpoint) = self.get_api_key().await?;
388 let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
389
390 let native_tools = Self::convert_tools(tools);
391 let request = ApiChatRequest {
392 model: model.to_string(),
393 messages,
394 temperature,
395 tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
396 tools: native_tools,
397 };
398
399 let mut req = self
400 .http_client()
401 .post(&url)
402 .header("Authorization", format!("Bearer {token}"))
403 .json(&request);
404
405 for (header, value) in &Self::COPILOT_HEADERS {
406 req = req.header(*header, *value);
407 }
408
409 let response = req.send().await?;
410
411 if !response.status().is_success() {
412 return Err(super::api_error("GitHub Copilot", response).await);
413 }
414
415 let api_response: ApiChatResponse = response.json().await?;
416 let usage = api_response.usage.map(|u| TokenUsage {
417 input_tokens: u.prompt_tokens,
418 output_tokens: u.completion_tokens,
419 cached_input_tokens: None,
420 });
421 let choice = api_response.choices.into_iter().next().ok_or_else(|| {
422 ::zeroclaw_log::record!(
423 ERROR,
424 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
425 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
426 "copilot: empty choices in response"
427 );
428 anyhow::Error::msg("No response from GitHub Copilot")
429 })?;
430
431 let tool_calls = choice
432 .message
433 .tool_calls
434 .unwrap_or_default()
435 .into_iter()
436 .map(|tool_call| ProviderToolCall {
437 id: tool_call
438 .id
439 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
440 name: tool_call.function.name,
441 arguments: tool_call.function.arguments,
442 extra_content: None,
443 })
444 .collect();
445
446 Ok(ProviderChatResponse {
447 text: choice.message.content,
448 tool_calls,
449 usage,
450 reasoning_content: None,
451 })
452 }
453
454 async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
457 let mut cached = self.refresh_lock.lock().await;
458
459 if let Some(cached_key) = cached.as_ref()
460 && chrono::Utc::now().timestamp() + 120 < cached_key.expires_at
461 {
462 return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
463 }
464
465 if let Some(info) = self.load_api_key_from_disk().await
466 && chrono::Utc::now().timestamp() + 120 < info.expires_at
467 {
468 let endpoint = info
469 .endpoints
470 .as_ref()
471 .and_then(|e| e.api.clone())
472 .unwrap_or_else(|| DEFAULT_API.to_string());
473 let token = info.token;
474
475 *cached = Some(CachedApiKey {
476 token: token.clone(),
477 api_endpoint: endpoint.clone(),
478 expires_at: info.expires_at,
479 });
480 return Ok((token, endpoint));
481 }
482
483 let access_token = self.get_github_access_token().await?;
484 let api_key_info = self.exchange_for_api_key(&access_token).await?;
485 self.save_api_key_to_disk(&api_key_info).await;
486
487 let endpoint = api_key_info
488 .endpoints
489 .as_ref()
490 .and_then(|e| e.api.clone())
491 .unwrap_or_else(|| DEFAULT_API.to_string());
492
493 *cached = Some(CachedApiKey {
494 token: api_key_info.token.clone(),
495 api_endpoint: endpoint.clone(),
496 expires_at: api_key_info.expires_at,
497 });
498
499 Ok((api_key_info.token, endpoint))
500 }
501
502 async fn get_github_access_token(&self) -> anyhow::Result<String> {
504 if let Some(token) = &self.github_token {
505 return Ok(token.clone());
506 }
507
508 let access_token_path = self.token_dir.join("access-token");
509 if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
510 let token = cached.trim();
511 if !token.is_empty() {
512 return Ok(token.to_string());
513 }
514 }
515
516 let token = self.device_code_login().await?;
517 write_file_secure(&access_token_path, &token).await;
518 Ok(token)
519 }
520
521 async fn device_code_login(&self) -> anyhow::Result<String> {
523 let response: DeviceCodeResponse = self
524 .http_client()
525 .post(GITHUB_DEVICE_CODE_URL)
526 .header("Accept", "application/json")
527 .json(&serde_json::json!({
528 "client_id": GITHUB_CLIENT_ID,
529 "scope": "read:user"
530 }))
531 .send()
532 .await?
533 .error_for_status()?
534 .json()
535 .await?;
536
537 let mut poll_interval = Duration::from_secs(response.interval.max(5));
538 let expires_in = response.expires_in.max(1);
539 let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
540
541 eprintln!(
542 "\nGitHub Copilot authentication is required.\n\
543 Visit: {}\n\
544 Code: {}\n\
545 Waiting for authorization...\n",
546 response.verification_uri, response.user_code
547 );
548
549 while tokio::time::Instant::now() < expires_at {
550 tokio::time::sleep(poll_interval).await;
551
552 let token_response: AccessTokenResponse = self
553 .http_client()
554 .post(GITHUB_ACCESS_TOKEN_URL)
555 .header("Accept", "application/json")
556 .json(&serde_json::json!({
557 "client_id": GITHUB_CLIENT_ID,
558 "device_code": response.device_code,
559 "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
560 }))
561 .send()
562 .await?
563 .json()
564 .await?;
565
566 if let Some(token) = token_response.access_token {
567 eprintln!("Authentication succeeded.\n");
568 return Ok(token);
569 }
570
571 match token_response.error.as_deref() {
572 Some("slow_down") => {
573 poll_interval += Duration::from_secs(5);
574 }
575 Some("authorization_pending") | None => {}
576 Some("expired_token") => {
577 anyhow::bail!("GitHub device authorization expired")
578 }
579 Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
580 }
581 }
582
583 anyhow::bail!("Timed out waiting for GitHub authorization")
584 }
585
586 async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
588 let mut request = self.http_client().get(GITHUB_API_KEY_URL);
589 for (header, value) in &Self::COPILOT_HEADERS {
590 request = request.header(*header, *value);
591 }
592 request = request.header("Authorization", format!("token {access_token}"));
593
594 let response = request.send().await?;
595
596 if !response.status().is_success() {
597 let status = response.status();
598 let body = response.text().await.unwrap_or_default();
599 let sanitized = super::sanitize_api_error(&body);
600
601 if status.as_u16() == 401 || status.as_u16() == 403 {
602 let access_token_path = self.token_dir.join("access-token");
603 tokio::fs::remove_file(&access_token_path).await.ok();
604 }
605
606 anyhow::bail!(
607 "Failed to get Copilot API key ({status}): {sanitized}. \
608 Ensure your GitHub account has an active Copilot subscription."
609 );
610 }
611
612 let info: ApiKeyInfo = response.json().await?;
613 Ok(info)
614 }
615
616 async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
617 let path = self.token_dir.join("api-key.json");
618 let data = tokio::fs::read_to_string(&path).await.ok()?;
619 serde_json::from_str(&data).ok()
620 }
621
622 async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
623 let path = self.token_dir.join("api-key.json");
624 if let Ok(json) = serde_json::to_string_pretty(info) {
625 write_file_secure(&path, &json).await;
626 }
627 }
628}
629
630async fn write_file_secure(path: &Path, content: &str) {
633 let path = path.to_path_buf();
634 let content = content.to_string();
635
636 let result = tokio::task::spawn_blocking(move || {
637 #[cfg(unix)]
638 {
639 use std::io::Write;
640 use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
641
642 let mut file = std::fs::OpenOptions::new()
643 .write(true)
644 .create(true)
645 .truncate(true)
646 .mode(0o600)
647 .open(&path)?;
648 file.write_all(content.as_bytes())?;
649
650 std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
651 Ok::<(), std::io::Error>(())
652 }
653 #[cfg(not(unix))]
654 {
655 std::fs::write(&path, &content)?;
656 Ok::<(), std::io::Error>(())
657 }
658 })
659 .await;
660
661 match result {
662 Ok(Ok(())) => {}
663 Ok(Err(err)) => ::zeroclaw_log::record!(
664 WARN,
665 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
666 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
667 .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
668 "Failed to write secure file"
669 ),
670 Err(err) => ::zeroclaw_log::record!(
671 WARN,
672 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
673 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
674 .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
675 "Failed to spawn blocking write"
676 ),
677 }
678}
679
680#[async_trait]
681impl ModelProvider for CopilotModelProvider {
682 fn default_base_url(&self) -> Option<&str> {
684 Some(DEFAULT_API)
685 }
686
687 async fn chat_with_system(
688 &self,
689 system_prompt: Option<&str>,
690 message: &str,
691 model: &str,
692 temperature: Option<f64>,
693 ) -> anyhow::Result<String> {
694 let mut messages = Vec::new();
695 if let Some(system) = system_prompt {
696 messages.push(ApiMessage {
697 role: "system".to_string(),
698 content: Some(ApiContent::Text(system.to_string())),
699 tool_call_id: None,
700 tool_calls: None,
701 });
702 }
703 messages.push(ApiMessage {
704 role: "user".to_string(),
705 content: Self::to_api_content("user", message),
706 tool_call_id: None,
707 tool_calls: None,
708 });
709
710 let response = self
711 .send_chat_request(messages, None, model, temperature)
712 .await?;
713 Ok(response.text.unwrap_or_default())
714 }
715
716 async fn chat_with_history(
717 &self,
718 messages: &[ChatMessage],
719 model: &str,
720 temperature: Option<f64>,
721 ) -> anyhow::Result<String> {
722 let response = self
723 .send_chat_request(Self::convert_messages(messages), None, model, temperature)
724 .await?;
725 Ok(response.text.unwrap_or_default())
726 }
727
728 async fn chat(
729 &self,
730 request: ProviderChatRequest<'_>,
731 model: &str,
732 temperature: Option<f64>,
733 ) -> anyhow::Result<ProviderChatResponse> {
734 self.send_chat_request(
735 Self::convert_messages(request.messages),
736 request.tools,
737 model,
738 temperature,
739 )
740 .await
741 }
742
743 fn supports_native_tools(&self) -> bool {
744 true
745 }
746
747 async fn warmup(&self) -> anyhow::Result<()> {
748 let _ = self.get_api_key().await?;
749 Ok(())
750 }
751}
752
753impl ::zeroclaw_api::attribution::Attributable for CopilotModelProvider {
754 fn role(&self) -> ::zeroclaw_api::attribution::Role {
755 ::zeroclaw_api::attribution::Role::Provider(
756 ::zeroclaw_api::attribution::ProviderKind::Model(
757 ::zeroclaw_api::attribution::ModelProviderKind::Copilot,
758 ),
759 )
760 }
761 fn alias(&self) -> &str {
762 &self.alias
763 }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with alias `"test"` and the given optional token.
    fn provider(token: Option<&str>) -> CopilotModelProvider {
        CopilotModelProvider::new("test", token)
    }

    #[test]
    fn new_without_token() {
        assert_eq!(provider(None).github_token, None);
    }

    #[test]
    fn new_with_token() {
        let built = provider(Some("ghp_test"));
        assert_eq!(built.github_token, Some("ghp_test".to_string()));
    }

    #[test]
    fn empty_token_treated_as_none() {
        assert!(provider(Some("")).github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let built = provider(None);
        assert!(built.refresh_lock.lock().await.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let names: Vec<&str> = CopilotModelProvider::COPILOT_HEADERS
            .iter()
            .map(|(name, _)| *name)
            .collect();
        for required in ["Editor-Version", "Editor-Plugin-Version", "User-Agent"] {
            assert!(names.contains(&required), "missing header {required}");
        }
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!((default_interval(), default_expires_in()), (5, 900));
    }

    #[test]
    fn supports_native_tools() {
        assert!(provider(None).supports_native_tools());
    }

    #[test]
    fn api_response_parses_usage() {
        let raw = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 200, "completion_tokens": 80}
        }"#;
        let usage = serde_json::from_str::<ApiChatResponse>(raw)
            .unwrap()
            .usage
            .expect("usage should be present");
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens),
            (Some(200), Some(80))
        );
    }

    #[test]
    fn api_response_parses_without_usage() {
        let raw = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
        let parsed: ApiChatResponse = serde_json::from_str(raw).unwrap();
        assert!(parsed.usage.is_none());
    }

    #[test]
    fn to_api_content_user_with_image_returns_parts() {
        let content = "describe this [IMAGE:data:image/png;base64,abc123]";
        let ApiContent::Parts(parts) =
            CopilotModelProvider::to_api_content("user", content).unwrap()
        else {
            panic!("expected ApiContent::Parts for user message with image marker");
        };
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
        assert!(
            matches!(&parts[1], ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123")
        );
    }

    #[test]
    fn to_api_content_user_plain_returns_text() {
        let result = CopilotModelProvider::to_api_content("user", "hello world");
        assert!(matches!(result, Some(ApiContent::Text(ref s)) if s == "hello world"));
    }

    #[test]
    fn to_api_content_non_user_returns_text() {
        for (role, text) in [("system", "you are helpful"), ("assistant", "sure")] {
            let result = CopilotModelProvider::to_api_content(role, text).unwrap();
            assert!(matches!(result, ApiContent::Text(ref s) if s == text));
        }
    }
}