Skip to main content

zeroclaw_providers/
copilot.rs

1//! GitHub Copilot model_provider with OAuth device-flow authentication.
2//!
3//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
4//! then exchanges the OAuth token for short-lived Copilot API keys.
5//! Tokens are cached to disk and auto-refreshed.
6//!
7//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
8//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
9//! and other third-party Copilot integrations. The Copilot token endpoint is
10//! private; there is no public OAuth scope or app registration for it.
11//! GitHub could change or revoke this at any time, which would break all
12//! third-party integrations simultaneously.
13
14use crate::traits::{
15    ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
16    ModelProvider, TokenUsage, ToolCall as ProviderToolCall,
17};
18use async_trait::async_trait;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::Mutex;
25use zeroclaw_api::tool::ToolSpec;
26
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// Device flow, step one: request a device code and a user code.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Device flow, step two: poll here until the user authorizes the device.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Private endpoint exchanging a GitHub OAuth token for a short-lived Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// Copilot API base URL used when the key exchange advertises no endpoint.
const DEFAULT_API: &str = "https://api.githubcopilot.com";
33
34// ── Token types ──────────────────────────────────────────────────
35
36#[derive(Debug, Deserialize)]
37struct DeviceCodeResponse {
38    device_code: String,
39    user_code: String,
40    verification_uri: String,
41    #[serde(default = "default_interval")]
42    interval: u64,
43    #[serde(default = "default_expires_in")]
44    expires_in: u64,
45}
46
/// Serde default for [`DeviceCodeResponse::interval`], in seconds.
fn default_interval() -> u64 {
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}
50
/// Serde default for [`DeviceCodeResponse::expires_in`], in seconds (15 minutes).
fn default_expires_in() -> u64 {
    const FALLBACK_LIFETIME_SECS: u64 = 15 * 60;
    FALLBACK_LIFETIME_SECS
}
54
55#[derive(Debug, Deserialize)]
56struct AccessTokenResponse {
57    access_token: Option<String>,
58    error: Option<String>,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
62struct ApiKeyInfo {
63    token: String,
64    expires_at: i64,
65    #[serde(default)]
66    endpoints: Option<ApiEndpoints>,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70struct ApiEndpoints {
71    api: Option<String>,
72}
73
/// In-memory copy of the current Copilot API key and its resolved endpoint.
struct CachedApiKey {
    /// Bearer token sent as `Authorization: Bearer …`.
    token: String,
    /// Base URL that `/chat/completions` is appended to.
    api_endpoint: String,
    /// Expiry as a Unix timestamp in seconds.
    expires_at: i64,
}
79
80// ── Chat completions types ───────────────────────────────────────
81
82#[derive(Debug, Serialize)]
83struct ApiChatRequest<'a> {
84    model: String,
85    messages: Vec<ApiMessage>,
86    temperature: f64,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    tools: Option<Vec<NativeToolSpec<'a>>>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    tool_choice: Option<String>,
91}
92
93#[derive(Debug, Serialize)]
94struct ApiMessage {
95    role: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    content: Option<ApiContent>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    tool_call_id: Option<String>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    tool_calls: Option<Vec<NativeToolCall>>,
102}
103
104#[derive(Debug, Serialize)]
105struct NativeToolSpec<'a> {
106    #[serde(rename = "type")]
107    kind: &'static str,
108    function: NativeToolFunctionSpec<'a>,
109}
110
111#[derive(Debug, Serialize)]
112struct NativeToolFunctionSpec<'a> {
113    name: &'a str,
114    description: &'a str,
115    parameters: &'a serde_json::Value,
116}
117
118#[derive(Debug, Serialize, Deserialize)]
119struct NativeToolCall {
120    #[serde(skip_serializing_if = "Option::is_none")]
121    id: Option<String>,
122    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
123    kind: Option<String>,
124    function: NativeFunctionCall,
125}
126
127#[derive(Debug, Serialize, Deserialize)]
128struct NativeFunctionCall {
129    name: String,
130    arguments: String,
131}
132
133/// Multi-part content for vision messages (OpenAI format).
134#[derive(Debug, Clone, Serialize)]
135#[serde(untagged)]
136enum ApiContent {
137    Text(String),
138    Parts(Vec<ContentPart>),
139}
140
141#[derive(Debug, Clone, Serialize)]
142#[serde(tag = "type")]
143enum ContentPart {
144    #[serde(rename = "text")]
145    Text { text: String },
146    #[serde(rename = "image_url")]
147    ImageUrl { image_url: ImageUrlDetail },
148}
149
150#[derive(Debug, Clone, Serialize)]
151struct ImageUrlDetail {
152    url: String,
153}
154
155#[derive(Debug, Deserialize)]
156struct ApiChatResponse {
157    choices: Vec<Choice>,
158    #[serde(default)]
159    usage: Option<UsageInfo>,
160}
161
162#[derive(Debug, Deserialize)]
163struct UsageInfo {
164    #[serde(default)]
165    prompt_tokens: Option<u64>,
166    #[serde(default)]
167    completion_tokens: Option<u64>,
168}
169
170#[derive(Debug, Deserialize)]
171struct Choice {
172    message: ResponseMessage,
173}
174
175#[derive(Debug, Deserialize)]
176struct ResponseMessage {
177    #[serde(default)]
178    content: Option<String>,
179    #[serde(default)]
180    tool_calls: Option<Vec<NativeToolCall>>,
181}
182
183// ── ModelProvider ─────────────────────────────────────────────────────
184
185/// GitHub Copilot model_provider with automatic OAuth and token refresh.
186///
187/// On first use, prompts the user to visit github.com/login/device.
188/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
189/// automatically.
190pub struct CopilotModelProvider {
191    /// `[model_providers.<family>.<alias>]` config-key alias.
192    alias: String,
193    github_token: Option<String>,
194    /// Mutex ensures only one caller refreshes tokens at a time,
195    /// preventing duplicate device flow prompts or redundant API calls.
196    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
197    token_dir: PathBuf,
198}
199
200impl CopilotModelProvider {
201    pub fn new(alias: &str, github_token: Option<&str>) -> Self {
202        let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
203            .map(|dir| dir.config_dir().join("copilot"))
204            .unwrap_or_else(|| {
205                // Fall back to a user-specific temp directory to avoid
206                // shared-directory symlink attacks.
207                let user = std::env::var("USER")
208                    .or_else(|_| std::env::var("USERNAME"))
209                    .unwrap_or_else(|_| "unknown".to_string());
210                std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
211            });
212
213        if let Err(err) = std::fs::create_dir_all(&token_dir) {
214            ::zeroclaw_log::record!(
215                WARN,
216                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
217                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
218                &format!(
219                    "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
220                    token_dir
221                )
222            );
223        } else {
224            #[cfg(unix)]
225            {
226                use std::os::unix::fs::PermissionsExt;
227
228                if let Err(err) =
229                    std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
230                {
231                    ::zeroclaw_log::record!(
232                        WARN,
233                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
234                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
235                        &format!(
236                            "Failed to set Copilot token directory permissions on {:?}: {err}",
237                            token_dir
238                        )
239                    );
240                }
241            }
242        }
243
244        Self {
245            alias: alias.to_string(),
246            github_token: github_token
247                .filter(|token| !token.is_empty())
248                .map(String::from),
249            refresh_lock: Arc::new(Mutex::new(None)),
250            token_dir,
251        }
252    }
253    fn http_client(&self) -> Client {
254        zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
255            "model_provider.copilot",
256            120,
257            10,
258        )
259    }
260
261    /// Required headers for Copilot API requests (editor identification).
262    const COPILOT_HEADERS: [(&str, &str); 4] = [
263        ("Editor-Version", "vscode/1.85.1"),
264        ("Editor-Plugin-Version", "copilot/1.155.0"),
265        ("User-Agent", "GithubCopilot/1.155.0"),
266        ("Accept", "application/json"),
267    ];
268
269    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
270        tools.map(|items| {
271            items
272                .iter()
273                .map(|tool| NativeToolSpec {
274                    kind: "function",
275                    function: NativeToolFunctionSpec {
276                        name: &tool.name,
277                        description: &tool.description,
278                        parameters: &tool.parameters,
279                    },
280                })
281                .collect()
282        })
283    }
284
285    /// Convert message content to API format, with multi-part support for
286    /// user messages containing `[IMAGE:...]` markers.
287    fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
288        if role != "user" {
289            return Some(ApiContent::Text(content.to_string()));
290        }
291
292        let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
293        if image_refs.is_empty() {
294            return Some(ApiContent::Text(content.to_string()));
295        }
296
297        let mut parts = Vec::with_capacity(image_refs.len() + 1);
298        let trimmed = cleaned_text.trim();
299        if !trimmed.is_empty() {
300            parts.push(ContentPart::Text {
301                text: trimmed.to_string(),
302            });
303        }
304        for image_ref in image_refs {
305            parts.push(ContentPart::ImageUrl {
306                image_url: ImageUrlDetail { url: image_ref },
307            });
308        }
309
310        Some(ApiContent::Parts(parts))
311    }
312
313    fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
314        messages
315            .iter()
316            .map(|message| {
317                if message.role == "assistant"
318                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
319                    && let Some(tool_calls_value) = value.get("tool_calls")
320                    && let Ok(parsed_calls) =
321                        serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
322                {
323                    let tool_calls = parsed_calls
324                        .into_iter()
325                        .map(|tool_call| NativeToolCall {
326                            id: Some(tool_call.id),
327                            kind: Some("function".to_string()),
328                            function: NativeFunctionCall {
329                                name: tool_call.name,
330                                arguments: tool_call.arguments,
331                            },
332                        })
333                        .collect::<Vec<_>>();
334
335                    let content = value
336                        .get("content")
337                        .and_then(serde_json::Value::as_str)
338                        .map(|s| ApiContent::Text(s.to_string()));
339
340                    return ApiMessage {
341                        role: "assistant".to_string(),
342                        content,
343                        tool_call_id: None,
344                        tool_calls: Some(tool_calls),
345                    };
346                }
347
348                if message.role == "tool"
349                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
350                {
351                    let tool_call_id = value
352                        .get("tool_call_id")
353                        .and_then(serde_json::Value::as_str)
354                        .map(ToString::to_string);
355                    let content = value
356                        .get("content")
357                        .and_then(serde_json::Value::as_str)
358                        .map(|s| ApiContent::Text(s.to_string()));
359
360                    return ApiMessage {
361                        role: "tool".to_string(),
362                        content,
363                        tool_call_id,
364                        tool_calls: None,
365                    };
366                }
367
368                ApiMessage {
369                    role: message.role.clone(),
370                    content: Self::to_api_content(&message.role, &message.content),
371                    tool_call_id: None,
372                    tool_calls: None,
373                }
374            })
375            .collect()
376    }
377
378    /// Send a chat completions request with required Copilot headers.
379    async fn send_chat_request(
380        &self,
381        messages: Vec<ApiMessage>,
382        tools: Option<&[ToolSpec]>,
383        model: &str,
384        temperature: f64,
385    ) -> anyhow::Result<ProviderChatResponse> {
386        let (token, endpoint) = self.get_api_key().await?;
387        let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
388
389        let native_tools = Self::convert_tools(tools);
390        let request = ApiChatRequest {
391            model: model.to_string(),
392            messages,
393            temperature,
394            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
395            tools: native_tools,
396        };
397
398        let mut req = self
399            .http_client()
400            .post(&url)
401            .header("Authorization", format!("Bearer {token}"))
402            .json(&request);
403
404        for (header, value) in &Self::COPILOT_HEADERS {
405            req = req.header(*header, *value);
406        }
407
408        let response = req.send().await?;
409
410        if !response.status().is_success() {
411            return Err(super::api_error("GitHub Copilot", response).await);
412        }
413
414        let api_response: ApiChatResponse = response.json().await?;
415        let usage = api_response.usage.map(|u| TokenUsage {
416            input_tokens: u.prompt_tokens,
417            output_tokens: u.completion_tokens,
418            cached_input_tokens: None,
419        });
420        let choice = api_response.choices.into_iter().next().ok_or_else(|| {
421            ::zeroclaw_log::record!(
422                ERROR,
423                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
424                    .with_outcome(::zeroclaw_log::EventOutcome::Failure),
425                "copilot: empty choices in response"
426            );
427            anyhow::Error::msg("No response from GitHub Copilot")
428        })?;
429
430        let tool_calls = choice
431            .message
432            .tool_calls
433            .unwrap_or_default()
434            .into_iter()
435            .map(|tool_call| ProviderToolCall {
436                id: tool_call
437                    .id
438                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
439                name: tool_call.function.name,
440                arguments: tool_call.function.arguments,
441                extra_content: None,
442            })
443            .collect();
444
445        Ok(ProviderChatResponse {
446            text: choice.message.content,
447            tool_calls,
448            usage,
449            reasoning_content: None,
450        })
451    }
452
453    /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
454    /// Uses a Mutex to ensure only one caller refreshes at a time.
455    async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
456        let mut cached = self.refresh_lock.lock().await;
457
458        if let Some(cached_key) = cached.as_ref()
459            && chrono::Utc::now().timestamp() + 120 < cached_key.expires_at
460        {
461            return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
462        }
463
464        if let Some(info) = self.load_api_key_from_disk().await
465            && chrono::Utc::now().timestamp() + 120 < info.expires_at
466        {
467            let endpoint = info
468                .endpoints
469                .as_ref()
470                .and_then(|e| e.api.clone())
471                .unwrap_or_else(|| DEFAULT_API.to_string());
472            let token = info.token;
473
474            *cached = Some(CachedApiKey {
475                token: token.clone(),
476                api_endpoint: endpoint.clone(),
477                expires_at: info.expires_at,
478            });
479            return Ok((token, endpoint));
480        }
481
482        let access_token = self.get_github_access_token().await?;
483        let api_key_info = self.exchange_for_api_key(&access_token).await?;
484        self.save_api_key_to_disk(&api_key_info).await;
485
486        let endpoint = api_key_info
487            .endpoints
488            .as_ref()
489            .and_then(|e| e.api.clone())
490            .unwrap_or_else(|| DEFAULT_API.to_string());
491
492        *cached = Some(CachedApiKey {
493            token: api_key_info.token.clone(),
494            api_endpoint: endpoint.clone(),
495            expires_at: api_key_info.expires_at,
496        });
497
498        Ok((api_key_info.token, endpoint))
499    }
500
501    /// Get a GitHub access token from config, cache, or device flow.
502    async fn get_github_access_token(&self) -> anyhow::Result<String> {
503        if let Some(token) = &self.github_token {
504            return Ok(token.clone());
505        }
506
507        let access_token_path = self.token_dir.join("access-token");
508        if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
509            let token = cached.trim();
510            if !token.is_empty() {
511                return Ok(token.to_string());
512            }
513        }
514
515        let token = self.device_code_login().await?;
516        write_file_secure(&access_token_path, &token).await;
517        Ok(token)
518    }
519
520    /// Run GitHub OAuth device code flow.
521    async fn device_code_login(&self) -> anyhow::Result<String> {
522        let response: DeviceCodeResponse = self
523            .http_client()
524            .post(GITHUB_DEVICE_CODE_URL)
525            .header("Accept", "application/json")
526            .json(&serde_json::json!({
527                "client_id": GITHUB_CLIENT_ID,
528                "scope": "read:user"
529            }))
530            .send()
531            .await?
532            .error_for_status()?
533            .json()
534            .await?;
535
536        let mut poll_interval = Duration::from_secs(response.interval.max(5));
537        let expires_in = response.expires_in.max(1);
538        let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
539
540        eprintln!(
541            "\nGitHub Copilot authentication is required.\n\
542             Visit: {}\n\
543             Code: {}\n\
544             Waiting for authorization...\n",
545            response.verification_uri, response.user_code
546        );
547
548        while tokio::time::Instant::now() < expires_at {
549            tokio::time::sleep(poll_interval).await;
550
551            let token_response: AccessTokenResponse = self
552                .http_client()
553                .post(GITHUB_ACCESS_TOKEN_URL)
554                .header("Accept", "application/json")
555                .json(&serde_json::json!({
556                    "client_id": GITHUB_CLIENT_ID,
557                    "device_code": response.device_code,
558                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
559                }))
560                .send()
561                .await?
562                .json()
563                .await?;
564
565            if let Some(token) = token_response.access_token {
566                eprintln!("Authentication succeeded.\n");
567                return Ok(token);
568            }
569
570            match token_response.error.as_deref() {
571                Some("slow_down") => {
572                    poll_interval += Duration::from_secs(5);
573                }
574                Some("authorization_pending") | None => {}
575                Some("expired_token") => {
576                    anyhow::bail!("GitHub device authorization expired")
577                }
578                Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
579            }
580        }
581
582        anyhow::bail!("Timed out waiting for GitHub authorization")
583    }
584
585    /// Exchange a GitHub access token for a Copilot API key.
586    async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
587        let mut request = self.http_client().get(GITHUB_API_KEY_URL);
588        for (header, value) in &Self::COPILOT_HEADERS {
589            request = request.header(*header, *value);
590        }
591        request = request.header("Authorization", format!("token {access_token}"));
592
593        let response = request.send().await?;
594
595        if !response.status().is_success() {
596            let status = response.status();
597            let body = response.text().await.unwrap_or_default();
598            let sanitized = super::sanitize_api_error(&body);
599
600            if status.as_u16() == 401 || status.as_u16() == 403 {
601                let access_token_path = self.token_dir.join("access-token");
602                tokio::fs::remove_file(&access_token_path).await.ok();
603            }
604
605            anyhow::bail!(
606                "Failed to get Copilot API key ({status}): {sanitized}. \
607                 Ensure your GitHub account has an active Copilot subscription."
608            );
609        }
610
611        let info: ApiKeyInfo = response.json().await?;
612        Ok(info)
613    }
614
615    async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
616        let path = self.token_dir.join("api-key.json");
617        let data = tokio::fs::read_to_string(&path).await.ok()?;
618        serde_json::from_str(&data).ok()
619    }
620
621    async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
622        let path = self.token_dir.join("api-key.json");
623        if let Ok(json) = serde_json::to_string_pretty(info) {
624            write_file_secure(&path, &json).await;
625        }
626    }
627}
628
629/// Write a file with 0600 permissions (owner read/write only).
630/// Uses `spawn_blocking` to avoid blocking the async runtime.
631async fn write_file_secure(path: &Path, content: &str) {
632    let path = path.to_path_buf();
633    let content = content.to_string();
634
635    let result = tokio::task::spawn_blocking(move || {
636        #[cfg(unix)]
637        {
638            use std::io::Write;
639            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
640
641            let mut file = std::fs::OpenOptions::new()
642                .write(true)
643                .create(true)
644                .truncate(true)
645                .mode(0o600)
646                .open(&path)?;
647            file.write_all(content.as_bytes())?;
648
649            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
650            Ok::<(), std::io::Error>(())
651        }
652        #[cfg(not(unix))]
653        {
654            std::fs::write(&path, &content)?;
655            Ok::<(), std::io::Error>(())
656        }
657    })
658    .await;
659
660    match result {
661        Ok(Ok(())) => {}
662        Ok(Err(err)) => ::zeroclaw_log::record!(
663            WARN,
664            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
665                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
666                .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
667            "Failed to write secure file"
668        ),
669        Err(err) => ::zeroclaw_log::record!(
670            WARN,
671            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
672                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
673                .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
674            "Failed to spawn blocking write"
675        ),
676    }
677}
678
679#[async_trait]
680impl ModelProvider for CopilotModelProvider {
681    // ── ModelProvider-family defaults ──
682    fn default_base_url(&self) -> Option<&str> {
683        Some(DEFAULT_API)
684    }
685
686    async fn chat_with_system(
687        &self,
688        system_prompt: Option<&str>,
689        message: &str,
690        model: &str,
691        temperature: Option<f64>,
692    ) -> anyhow::Result<String> {
693        let temperature = temperature.unwrap_or(self.default_temperature());
694
695        let mut messages = Vec::new();
696        if let Some(system) = system_prompt {
697            messages.push(ApiMessage {
698                role: "system".to_string(),
699                content: Some(ApiContent::Text(system.to_string())),
700                tool_call_id: None,
701                tool_calls: None,
702            });
703        }
704        messages.push(ApiMessage {
705            role: "user".to_string(),
706            content: Self::to_api_content("user", message),
707            tool_call_id: None,
708            tool_calls: None,
709        });
710
711        let response = self
712            .send_chat_request(messages, None, model, temperature)
713            .await?;
714        Ok(response.text.unwrap_or_default())
715    }
716
717    async fn chat_with_history(
718        &self,
719        messages: &[ChatMessage],
720        model: &str,
721        temperature: Option<f64>,
722    ) -> anyhow::Result<String> {
723        let temperature = temperature.unwrap_or(self.default_temperature());
724        let response = self
725            .send_chat_request(Self::convert_messages(messages), None, model, temperature)
726            .await?;
727        Ok(response.text.unwrap_or_default())
728    }
729
730    async fn chat(
731        &self,
732        request: ProviderChatRequest<'_>,
733        model: &str,
734        temperature: Option<f64>,
735    ) -> anyhow::Result<ProviderChatResponse> {
736        let temperature = temperature.unwrap_or(self.default_temperature());
737        self.send_chat_request(
738            Self::convert_messages(request.messages),
739            request.tools,
740            model,
741            temperature,
742        )
743        .await
744    }
745
746    fn supports_native_tools(&self) -> bool {
747        true
748    }
749
750    async fn warmup(&self) -> anyhow::Result<()> {
751        let _ = self.get_api_key().await?;
752        Ok(())
753    }
754}
755
756impl ::zeroclaw_api::attribution::Attributable for CopilotModelProvider {
757    fn role(&self) -> ::zeroclaw_api::attribution::Role {
758        ::zeroclaw_api::attribution::Role::Provider(
759            ::zeroclaw_api::attribution::ProviderKind::Model(
760                ::zeroclaw_api::attribution::ModelProviderKind::Copilot,
761            ),
762        )
763    }
764    fn alias(&self) -> &str {
765        &self.alias
766    }
767}
768
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a provider under the alias shared by every test.
    fn provider_with(token: Option<&str>) -> CopilotModelProvider {
        CopilotModelProvider::new("test", token)
    }

    #[test]
    fn new_without_token() {
        let model_provider = provider_with(None);
        assert!(model_provider.github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let model_provider = provider_with(Some("ghp_test"));
        assert_eq!(model_provider.github_token.as_deref(), Some("ghp_test"));
    }

    #[test]
    fn empty_token_treated_as_none() {
        let model_provider = provider_with(Some(""));
        assert!(model_provider.github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let model_provider = provider_with(None);
        let guard = model_provider.refresh_lock.lock().await;
        assert!(guard.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let has_header = |name: &str| {
            CopilotModelProvider::COPILOT_HEADERS
                .iter()
                .any(|(header, _)| *header == name)
        };
        assert!(has_header("Editor-Version"));
        assert!(has_header("Editor-Plugin-Version"));
        assert!(has_header("User-Agent"));
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!((default_interval(), default_expires_in()), (5, 900));
    }

    #[test]
    fn supports_native_tools() {
        let model_provider = provider_with(None);
        assert!(model_provider.supports_native_tools());
    }

    #[test]
    fn api_response_parses_usage() {
        let body = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 200, "completion_tokens": 80}
        }"#;
        let parsed: ApiChatResponse = serde_json::from_str(body).unwrap();
        let usage = parsed.usage.unwrap();
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens),
            (Some(200), Some(80))
        );
    }

    #[test]
    fn api_response_parses_without_usage() {
        let body = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
        let parsed: ApiChatResponse = serde_json::from_str(body).unwrap();
        assert!(parsed.usage.is_none());
    }

    #[test]
    fn to_api_content_user_with_image_returns_parts() {
        let content = "describe this [IMAGE:data:image/png;base64,abc123]";
        let ApiContent::Parts(parts) =
            CopilotModelProvider::to_api_content("user", content).unwrap()
        else {
            panic!("expected ApiContent::Parts for user message with image marker")
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
        assert!(matches!(
            &parts[1],
            ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123"
        ));
    }

    #[test]
    fn to_api_content_user_plain_returns_text() {
        let converted = CopilotModelProvider::to_api_content("user", "hello world").unwrap();
        assert!(matches!(converted, ApiContent::Text(ref s) if s == "hello world"));
    }

    #[test]
    fn to_api_content_non_user_returns_text() {
        for (role, text) in [("system", "you are helpful"), ("assistant", "sure")] {
            let converted = CopilotModelProvider::to_api_content(role, text).unwrap();
            assert!(matches!(converted, ApiContent::Text(ref s) if s == text));
        }
    }
}