Skip to main content

zeroclaw_providers/
copilot.rs

1//! GitHub Copilot model_provider with OAuth device-flow authentication.
2//!
3//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
4//! then exchanges the OAuth token for short-lived Copilot API keys.
5//! Tokens are cached to disk and auto-refreshed.
6//!
7//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
8//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
9//! and other third-party Copilot integrations. The Copilot token endpoint is
10//! private; there is no public OAuth scope or app registration for it.
11//! GitHub could change or revoke this at any time, which would break all
12//! third-party integrations simultaneously.
13
14use crate::traits::{
15    ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
16    ModelProvider, TokenUsage, ToolCall as ProviderToolCall,
17};
18use async_trait::async_trait;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::Mutex;
25use zeroclaw_api::tool::ToolSpec;
26
/// Copilot chat API base URL, used whenever the API-key response does not
/// advertise an endpoint of its own.
const DEFAULT_API: &str = "https://api.githubcopilot.com";

/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

/// Device flow, step 1: request a device code / user code pair.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Device flow, step 2: poll here until the user authorizes the device.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Private endpoint trading a GitHub access token for a short-lived Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
33
34// ── Token types ──────────────────────────────────────────────────
35
36#[derive(Debug, Deserialize)]
37struct DeviceCodeResponse {
38    device_code: String,
39    user_code: String,
40    verification_uri: String,
41    #[serde(default = "default_interval")]
42    interval: u64,
43    #[serde(default = "default_expires_in")]
44    expires_in: u64,
45}
46
/// Poll interval, in seconds, assumed when GitHub omits `interval`.
fn default_interval() -> u64 {
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}

/// Device-code lifetime, in seconds, assumed when GitHub omits `expires_in` (15 minutes).
fn default_expires_in() -> u64 {
    15 * 60
}
54
55#[derive(Debug, Deserialize)]
56struct AccessTokenResponse {
57    access_token: Option<String>,
58    error: Option<String>,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
62struct ApiKeyInfo {
63    token: String,
64    expires_at: i64,
65    #[serde(default)]
66    endpoints: Option<ApiEndpoints>,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70struct ApiEndpoints {
71    api: Option<String>,
72}
73
74struct CachedApiKey {
75    token: String,
76    api_endpoint: String,
77    expires_at: i64,
78}
79
80// ── Chat completions types ───────────────────────────────────────
81
82#[derive(Debug, Serialize)]
83struct ApiChatRequest<'a> {
84    model: String,
85    messages: Vec<ApiMessage>,
86    #[serde(skip_serializing_if = "Option::is_none")]
87    temperature: Option<f64>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    tools: Option<Vec<NativeToolSpec<'a>>>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    tool_choice: Option<String>,
92}
93
94#[derive(Debug, Serialize)]
95struct ApiMessage {
96    role: String,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    content: Option<ApiContent>,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    tool_call_id: Option<String>,
101    #[serde(skip_serializing_if = "Option::is_none")]
102    tool_calls: Option<Vec<NativeToolCall>>,
103}
104
105#[derive(Debug, Serialize)]
106struct NativeToolSpec<'a> {
107    #[serde(rename = "type")]
108    kind: &'static str,
109    function: NativeToolFunctionSpec<'a>,
110}
111
112#[derive(Debug, Serialize)]
113struct NativeToolFunctionSpec<'a> {
114    name: &'a str,
115    description: &'a str,
116    parameters: &'a serde_json::Value,
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120struct NativeToolCall {
121    #[serde(skip_serializing_if = "Option::is_none")]
122    id: Option<String>,
123    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
124    kind: Option<String>,
125    function: NativeFunctionCall,
126}
127
128#[derive(Debug, Serialize, Deserialize)]
129struct NativeFunctionCall {
130    name: String,
131    arguments: String,
132}
133
134/// Multi-part content for vision messages (OpenAI format).
135#[derive(Debug, Clone, Serialize)]
136#[serde(untagged)]
137enum ApiContent {
138    Text(String),
139    Parts(Vec<ContentPart>),
140}
141
142#[derive(Debug, Clone, Serialize)]
143#[serde(tag = "type")]
144enum ContentPart {
145    #[serde(rename = "text")]
146    Text { text: String },
147    #[serde(rename = "image_url")]
148    ImageUrl { image_url: ImageUrlDetail },
149}
150
151#[derive(Debug, Clone, Serialize)]
152struct ImageUrlDetail {
153    url: String,
154}
155
156#[derive(Debug, Deserialize)]
157struct ApiChatResponse {
158    choices: Vec<Choice>,
159    #[serde(default)]
160    usage: Option<UsageInfo>,
161}
162
163#[derive(Debug, Deserialize)]
164struct UsageInfo {
165    #[serde(default)]
166    prompt_tokens: Option<u64>,
167    #[serde(default)]
168    completion_tokens: Option<u64>,
169}
170
171#[derive(Debug, Deserialize)]
172struct Choice {
173    message: ResponseMessage,
174}
175
176#[derive(Debug, Deserialize)]
177struct ResponseMessage {
178    #[serde(default)]
179    content: Option<String>,
180    #[serde(default)]
181    tool_calls: Option<Vec<NativeToolCall>>,
182}
183
184// ── ModelProvider ─────────────────────────────────────────────────────
185
186/// GitHub Copilot model_provider with automatic OAuth and token refresh.
187///
188/// On first use, prompts the user to visit github.com/login/device.
189/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
190/// automatically.
191pub struct CopilotModelProvider {
192    /// `[providers.models.<family>.<alias>]` config-key alias.
193    alias: String,
194    github_token: Option<String>,
195    /// Mutex ensures only one caller refreshes tokens at a time,
196    /// preventing duplicate device flow prompts or redundant API calls.
197    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
198    token_dir: PathBuf,
199}
200
201impl CopilotModelProvider {
202    pub fn new(alias: &str, github_token: Option<&str>) -> Self {
203        let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
204            .map(|dir| dir.config_dir().join("copilot"))
205            .unwrap_or_else(|| {
206                // Fall back to a user-specific temp directory to avoid
207                // shared-directory symlink attacks.
208                let user = std::env::var("USER")
209                    .or_else(|_| std::env::var("USERNAME"))
210                    .unwrap_or_else(|_| "unknown".to_string());
211                std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
212            });
213
214        if let Err(err) = std::fs::create_dir_all(&token_dir) {
215            ::zeroclaw_log::record!(
216                WARN,
217                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
218                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
219                &format!(
220                    "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
221                    token_dir
222                )
223            );
224        } else {
225            #[cfg(unix)]
226            {
227                use std::os::unix::fs::PermissionsExt;
228
229                if let Err(err) =
230                    std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
231                {
232                    ::zeroclaw_log::record!(
233                        WARN,
234                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
235                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
236                        &format!(
237                            "Failed to set Copilot token directory permissions on {:?}: {err}",
238                            token_dir
239                        )
240                    );
241                }
242            }
243        }
244
245        Self {
246            alias: alias.to_string(),
247            github_token: github_token
248                .filter(|token| !token.is_empty())
249                .map(String::from),
250            refresh_lock: Arc::new(Mutex::new(None)),
251            token_dir,
252        }
253    }
254    fn http_client(&self) -> Client {
255        zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
256            "model_provider.copilot",
257            120,
258            10,
259        )
260    }
261
262    /// Required headers for Copilot API requests (editor identification).
263    const COPILOT_HEADERS: [(&str, &str); 4] = [
264        ("Editor-Version", "vscode/1.85.1"),
265        ("Editor-Plugin-Version", "copilot/1.155.0"),
266        ("User-Agent", "GithubCopilot/1.155.0"),
267        ("Accept", "application/json"),
268    ];
269
270    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
271        tools.map(|items| {
272            items
273                .iter()
274                .map(|tool| NativeToolSpec {
275                    kind: "function",
276                    function: NativeToolFunctionSpec {
277                        name: &tool.name,
278                        description: &tool.description,
279                        parameters: &tool.parameters,
280                    },
281                })
282                .collect()
283        })
284    }
285
286    /// Convert message content to API format, with multi-part support for
287    /// user messages containing `[IMAGE:...]` markers.
288    fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
289        if role != "user" {
290            return Some(ApiContent::Text(content.to_string()));
291        }
292
293        let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
294        if image_refs.is_empty() {
295            return Some(ApiContent::Text(content.to_string()));
296        }
297
298        let mut parts = Vec::with_capacity(image_refs.len() + 1);
299        let trimmed = cleaned_text.trim();
300        if !trimmed.is_empty() {
301            parts.push(ContentPart::Text {
302                text: trimmed.to_string(),
303            });
304        }
305        for image_ref in image_refs {
306            parts.push(ContentPart::ImageUrl {
307                image_url: ImageUrlDetail { url: image_ref },
308            });
309        }
310
311        Some(ApiContent::Parts(parts))
312    }
313
314    fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
315        messages
316            .iter()
317            .map(|message| {
318                if message.role == "assistant"
319                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
320                    && let Some(tool_calls_value) = value.get("tool_calls")
321                    && let Ok(parsed_calls) =
322                        serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
323                {
324                    let tool_calls = parsed_calls
325                        .into_iter()
326                        .map(|tool_call| NativeToolCall {
327                            id: Some(tool_call.id),
328                            kind: Some("function".to_string()),
329                            function: NativeFunctionCall {
330                                name: tool_call.name,
331                                arguments: tool_call.arguments,
332                            },
333                        })
334                        .collect::<Vec<_>>();
335
336                    let content = value
337                        .get("content")
338                        .and_then(serde_json::Value::as_str)
339                        .map(|s| ApiContent::Text(s.to_string()));
340
341                    return ApiMessage {
342                        role: "assistant".to_string(),
343                        content,
344                        tool_call_id: None,
345                        tool_calls: Some(tool_calls),
346                    };
347                }
348
349                if message.role == "tool"
350                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
351                {
352                    let tool_call_id = value
353                        .get("tool_call_id")
354                        .and_then(serde_json::Value::as_str)
355                        .map(ToString::to_string);
356                    let content = value
357                        .get("content")
358                        .and_then(serde_json::Value::as_str)
359                        .map(|s| ApiContent::Text(s.to_string()));
360
361                    return ApiMessage {
362                        role: "tool".to_string(),
363                        content,
364                        tool_call_id,
365                        tool_calls: None,
366                    };
367                }
368
369                ApiMessage {
370                    role: message.role.clone(),
371                    content: Self::to_api_content(&message.role, &message.content),
372                    tool_call_id: None,
373                    tool_calls: None,
374                }
375            })
376            .collect()
377    }
378
379    /// Send a chat completions request with required Copilot headers.
380    async fn send_chat_request(
381        &self,
382        messages: Vec<ApiMessage>,
383        tools: Option<&[ToolSpec]>,
384        model: &str,
385        temperature: Option<f64>,
386    ) -> anyhow::Result<ProviderChatResponse> {
387        let (token, endpoint) = self.get_api_key().await?;
388        let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
389
390        let native_tools = Self::convert_tools(tools);
391        let request = ApiChatRequest {
392            model: model.to_string(),
393            messages,
394            temperature,
395            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
396            tools: native_tools,
397        };
398
399        let mut req = self
400            .http_client()
401            .post(&url)
402            .header("Authorization", format!("Bearer {token}"))
403            .json(&request);
404
405        for (header, value) in &Self::COPILOT_HEADERS {
406            req = req.header(*header, *value);
407        }
408
409        let response = req.send().await?;
410
411        if !response.status().is_success() {
412            return Err(super::api_error("GitHub Copilot", response).await);
413        }
414
415        let api_response: ApiChatResponse = response.json().await?;
416        let usage = api_response.usage.map(|u| TokenUsage {
417            input_tokens: u.prompt_tokens,
418            output_tokens: u.completion_tokens,
419            cached_input_tokens: None,
420        });
421        let choice = api_response.choices.into_iter().next().ok_or_else(|| {
422            ::zeroclaw_log::record!(
423                ERROR,
424                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
425                    .with_outcome(::zeroclaw_log::EventOutcome::Failure),
426                "copilot: empty choices in response"
427            );
428            anyhow::Error::msg("No response from GitHub Copilot")
429        })?;
430
431        let tool_calls = choice
432            .message
433            .tool_calls
434            .unwrap_or_default()
435            .into_iter()
436            .map(|tool_call| ProviderToolCall {
437                id: tool_call
438                    .id
439                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
440                name: tool_call.function.name,
441                arguments: tool_call.function.arguments,
442                extra_content: None,
443            })
444            .collect();
445
446        Ok(ProviderChatResponse {
447            text: choice.message.content,
448            tool_calls,
449            usage,
450            reasoning_content: None,
451        })
452    }
453
454    /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
455    /// Uses a Mutex to ensure only one caller refreshes at a time.
456    async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
457        let mut cached = self.refresh_lock.lock().await;
458
459        if let Some(cached_key) = cached.as_ref()
460            && chrono::Utc::now().timestamp() + 120 < cached_key.expires_at
461        {
462            return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
463        }
464
465        if let Some(info) = self.load_api_key_from_disk().await
466            && chrono::Utc::now().timestamp() + 120 < info.expires_at
467        {
468            let endpoint = info
469                .endpoints
470                .as_ref()
471                .and_then(|e| e.api.clone())
472                .unwrap_or_else(|| DEFAULT_API.to_string());
473            let token = info.token;
474
475            *cached = Some(CachedApiKey {
476                token: token.clone(),
477                api_endpoint: endpoint.clone(),
478                expires_at: info.expires_at,
479            });
480            return Ok((token, endpoint));
481        }
482
483        let access_token = self.get_github_access_token().await?;
484        let api_key_info = self.exchange_for_api_key(&access_token).await?;
485        self.save_api_key_to_disk(&api_key_info).await;
486
487        let endpoint = api_key_info
488            .endpoints
489            .as_ref()
490            .and_then(|e| e.api.clone())
491            .unwrap_or_else(|| DEFAULT_API.to_string());
492
493        *cached = Some(CachedApiKey {
494            token: api_key_info.token.clone(),
495            api_endpoint: endpoint.clone(),
496            expires_at: api_key_info.expires_at,
497        });
498
499        Ok((api_key_info.token, endpoint))
500    }
501
502    /// Get a GitHub access token from config, cache, or device flow.
503    async fn get_github_access_token(&self) -> anyhow::Result<String> {
504        if let Some(token) = &self.github_token {
505            return Ok(token.clone());
506        }
507
508        let access_token_path = self.token_dir.join("access-token");
509        if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
510            let token = cached.trim();
511            if !token.is_empty() {
512                return Ok(token.to_string());
513            }
514        }
515
516        let token = self.device_code_login().await?;
517        write_file_secure(&access_token_path, &token).await;
518        Ok(token)
519    }
520
521    /// Run GitHub OAuth device code flow.
522    async fn device_code_login(&self) -> anyhow::Result<String> {
523        let response: DeviceCodeResponse = self
524            .http_client()
525            .post(GITHUB_DEVICE_CODE_URL)
526            .header("Accept", "application/json")
527            .json(&serde_json::json!({
528                "client_id": GITHUB_CLIENT_ID,
529                "scope": "read:user"
530            }))
531            .send()
532            .await?
533            .error_for_status()?
534            .json()
535            .await?;
536
537        let mut poll_interval = Duration::from_secs(response.interval.max(5));
538        let expires_in = response.expires_in.max(1);
539        let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
540
541        eprintln!(
542            "\nGitHub Copilot authentication is required.\n\
543             Visit: {}\n\
544             Code: {}\n\
545             Waiting for authorization...\n",
546            response.verification_uri, response.user_code
547        );
548
549        while tokio::time::Instant::now() < expires_at {
550            tokio::time::sleep(poll_interval).await;
551
552            let token_response: AccessTokenResponse = self
553                .http_client()
554                .post(GITHUB_ACCESS_TOKEN_URL)
555                .header("Accept", "application/json")
556                .json(&serde_json::json!({
557                    "client_id": GITHUB_CLIENT_ID,
558                    "device_code": response.device_code,
559                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
560                }))
561                .send()
562                .await?
563                .json()
564                .await?;
565
566            if let Some(token) = token_response.access_token {
567                eprintln!("Authentication succeeded.\n");
568                return Ok(token);
569            }
570
571            match token_response.error.as_deref() {
572                Some("slow_down") => {
573                    poll_interval += Duration::from_secs(5);
574                }
575                Some("authorization_pending") | None => {}
576                Some("expired_token") => {
577                    anyhow::bail!("GitHub device authorization expired")
578                }
579                Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
580            }
581        }
582
583        anyhow::bail!("Timed out waiting for GitHub authorization")
584    }
585
586    /// Exchange a GitHub access token for a Copilot API key.
587    async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
588        let mut request = self.http_client().get(GITHUB_API_KEY_URL);
589        for (header, value) in &Self::COPILOT_HEADERS {
590            request = request.header(*header, *value);
591        }
592        request = request.header("Authorization", format!("token {access_token}"));
593
594        let response = request.send().await?;
595
596        if !response.status().is_success() {
597            let status = response.status();
598            let body = response.text().await.unwrap_or_default();
599            let sanitized = super::sanitize_api_error(&body);
600
601            if status.as_u16() == 401 || status.as_u16() == 403 {
602                let access_token_path = self.token_dir.join("access-token");
603                tokio::fs::remove_file(&access_token_path).await.ok();
604            }
605
606            anyhow::bail!(
607                "Failed to get Copilot API key ({status}): {sanitized}. \
608                 Ensure your GitHub account has an active Copilot subscription."
609            );
610        }
611
612        let info: ApiKeyInfo = response.json().await?;
613        Ok(info)
614    }
615
616    async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
617        let path = self.token_dir.join("api-key.json");
618        let data = tokio::fs::read_to_string(&path).await.ok()?;
619        serde_json::from_str(&data).ok()
620    }
621
622    async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
623        let path = self.token_dir.join("api-key.json");
624        if let Ok(json) = serde_json::to_string_pretty(info) {
625            write_file_secure(&path, &json).await;
626        }
627    }
628}
629
630/// Write a file with 0600 permissions (owner read/write only).
631/// Uses `spawn_blocking` to avoid blocking the async runtime.
632async fn write_file_secure(path: &Path, content: &str) {
633    let path = path.to_path_buf();
634    let content = content.to_string();
635
636    let result = tokio::task::spawn_blocking(move || {
637        #[cfg(unix)]
638        {
639            use std::io::Write;
640            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
641
642            let mut file = std::fs::OpenOptions::new()
643                .write(true)
644                .create(true)
645                .truncate(true)
646                .mode(0o600)
647                .open(&path)?;
648            file.write_all(content.as_bytes())?;
649
650            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
651            Ok::<(), std::io::Error>(())
652        }
653        #[cfg(not(unix))]
654        {
655            std::fs::write(&path, &content)?;
656            Ok::<(), std::io::Error>(())
657        }
658    })
659    .await;
660
661    match result {
662        Ok(Ok(())) => {}
663        Ok(Err(err)) => ::zeroclaw_log::record!(
664            WARN,
665            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
666                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
667                .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
668            "Failed to write secure file"
669        ),
670        Err(err) => ::zeroclaw_log::record!(
671            WARN,
672            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
673                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
674                .with_attrs(::serde_json::json!({"error": format!("{}", err)})),
675            "Failed to spawn blocking write"
676        ),
677    }
678}
679
680#[async_trait]
681impl ModelProvider for CopilotModelProvider {
682    // ── ModelProvider-family defaults ──
683    fn default_base_url(&self) -> Option<&str> {
684        Some(DEFAULT_API)
685    }
686
687    async fn chat_with_system(
688        &self,
689        system_prompt: Option<&str>,
690        message: &str,
691        model: &str,
692        temperature: Option<f64>,
693    ) -> anyhow::Result<String> {
694        let mut messages = Vec::new();
695        if let Some(system) = system_prompt {
696            messages.push(ApiMessage {
697                role: "system".to_string(),
698                content: Some(ApiContent::Text(system.to_string())),
699                tool_call_id: None,
700                tool_calls: None,
701            });
702        }
703        messages.push(ApiMessage {
704            role: "user".to_string(),
705            content: Self::to_api_content("user", message),
706            tool_call_id: None,
707            tool_calls: None,
708        });
709
710        let response = self
711            .send_chat_request(messages, None, model, temperature)
712            .await?;
713        Ok(response.text.unwrap_or_default())
714    }
715
716    async fn chat_with_history(
717        &self,
718        messages: &[ChatMessage],
719        model: &str,
720        temperature: Option<f64>,
721    ) -> anyhow::Result<String> {
722        let response = self
723            .send_chat_request(Self::convert_messages(messages), None, model, temperature)
724            .await?;
725        Ok(response.text.unwrap_or_default())
726    }
727
728    async fn chat(
729        &self,
730        request: ProviderChatRequest<'_>,
731        model: &str,
732        temperature: Option<f64>,
733    ) -> anyhow::Result<ProviderChatResponse> {
734        self.send_chat_request(
735            Self::convert_messages(request.messages),
736            request.tools,
737            model,
738            temperature,
739        )
740        .await
741    }
742
743    fn supports_native_tools(&self) -> bool {
744        true
745    }
746
747    async fn warmup(&self) -> anyhow::Result<()> {
748        let _ = self.get_api_key().await?;
749        Ok(())
750    }
751}
752
753impl ::zeroclaw_api::attribution::Attributable for CopilotModelProvider {
754    fn role(&self) -> ::zeroclaw_api::attribution::Role {
755        ::zeroclaw_api::attribution::Role::Provider(
756            ::zeroclaw_api::attribution::ProviderKind::Model(
757                ::zeroclaw_api::attribution::ModelProviderKind::Copilot,
758            ),
759        )
760    }
761    fn alias(&self) -> &str {
762        &self.alias
763    }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_without_token() {
        let model_provider = CopilotModelProvider::new("test", None);
        assert!(model_provider.github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let model_provider = CopilotModelProvider::new("test", Some("ghp_test"));
        assert_eq!(model_provider.github_token.as_deref(), Some("ghp_test"));
    }

    #[test]
    fn empty_token_treated_as_none() {
        let model_provider = CopilotModelProvider::new("test", Some(""));
        assert!(model_provider.github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let model_provider = CopilotModelProvider::new("test", None);
        let cached = model_provider.refresh_lock.lock().await;
        assert!(cached.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let headers = CopilotModelProvider::COPILOT_HEADERS;
        assert!(
            headers
                .iter()
                .any(|(header, _)| *header == "Editor-Version")
        );
        assert!(
            headers
                .iter()
                .any(|(header, _)| *header == "Editor-Plugin-Version")
        );
        assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }

    #[test]
    fn device_code_response_applies_defaults() {
        let resp: DeviceCodeResponse = serde_json::from_str(
            r#"{"device_code":"dc","user_code":"ABCD-1234","verification_uri":"https://github.com/login/device"}"#,
        )
        .unwrap();
        assert_eq!(resp.interval, 5);
        assert_eq!(resp.expires_in, 900);
    }

    #[test]
    fn api_key_info_parses_without_endpoints() {
        let info: ApiKeyInfo =
            serde_json::from_str(r#"{"token":"tid=abc","expires_at":1700000000}"#).unwrap();
        assert_eq!(info.token, "tid=abc");
        assert_eq!(info.expires_at, 1_700_000_000);
        assert!(info.endpoints.is_none());
    }

    #[test]
    fn supports_native_tools() {
        let model_provider = CopilotModelProvider::new("test", None);
        assert!(model_provider.supports_native_tools());
    }

    #[test]
    fn convert_tools_without_tools_is_none() {
        assert!(CopilotModelProvider::convert_tools(None).is_none());
    }

    #[test]
    fn chat_request_omits_unset_optional_fields() {
        let request = ApiChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ApiMessage {
                role: "user".to_string(),
                content: Some(ApiContent::Text("hi".to_string())),
                tool_call_id: None,
                tool_calls: None,
            }],
            temperature: None,
            tools: None,
            tool_choice: None,
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "hi"}]
            })
        );
    }

    #[test]
    fn multipart_content_serializes_with_type_tags() {
        let content = ApiContent::Parts(vec![
            ContentPart::Text {
                text: "look".to_string(),
            },
            ContentPart::ImageUrl {
                image_url: ImageUrlDetail {
                    url: "https://example.com/a.png".to_string(),
                },
            },
        ]);
        let value = serde_json::to_value(&content).unwrap();
        assert_eq!(
            value,
            serde_json::json!([
                {"type": "text", "text": "look"},
                {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}
            ])
        );
    }

    #[test]
    fn api_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 200, "completion_tokens": 80}
        }"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(200));
        assert_eq!(usage.completion_tokens, Some(80));
    }

    #[test]
    fn api_response_parses_without_usage() {
        let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.usage.is_none());
    }

    #[test]
    fn response_parses_native_tool_calls() {
        let json = r#"{"choices":[{"message":{"content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"cmd\":\"ls\"}"}}]}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        let message = &resp.choices[0].message;
        assert!(message.content.is_none());
        let calls = message.tool_calls.as_ref().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id.as_deref(), Some("call_1"));
        assert_eq!(calls[0].kind.as_deref(), Some("function"));
        assert_eq!(calls[0].function.name, "shell");
        assert_eq!(calls[0].function.arguments, r#"{"cmd":"ls"}"#);
    }

    #[test]
    fn to_api_content_user_with_image_returns_parts() {
        let content = "describe this [IMAGE:data:image/png;base64,abc123]";
        let result = CopilotModelProvider::to_api_content("user", content).unwrap();
        match result {
            ApiContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
                assert!(
                    matches!(&parts[1], ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123")
                );
            }
            ApiContent::Text(_) => {
                panic!("expected ApiContent::Parts for user message with image marker")
            }
        }
    }

    #[test]
    fn to_api_content_user_plain_returns_text() {
        let result = CopilotModelProvider::to_api_content("user", "hello world").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "hello world"));
    }

    #[test]
    fn to_api_content_non_user_returns_text() {
        let result = CopilotModelProvider::to_api_content("system", "you are helpful").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "you are helpful"));

        let result = CopilotModelProvider::to_api_content("assistant", "sure").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "sure"));
    }
}