Skip to main content

zeroclaw_config/cost/
tracker.rs

1use super::types::{
2    AgentCostStats, BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod,
3};
4use crate::schema::CostConfig;
5use anyhow::{Context, Result};
6use chrono::{DateTime, Datelike, NaiveDate, Utc};
7use parking_lot::{Mutex, MutexGuard};
8use std::collections::HashMap;
9use std::fs::{self, File, OpenOptions};
10use std::io::{BufRead, BufReader, Write};
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, OnceLock};
13
14/// Cost tracker for API usage monitoring and budget enforcement.
15pub struct CostTracker {
16    config: CostConfig,
17    storage: Arc<Mutex<CostStorage>>,
18    session_id: String,
19    /// Per-daemon-lifetime aggregates keyed by `Option<agent_alias>`,
20    /// replacing the unbounded per-turn `Vec<CostRecord>`.
21    session_totals: Arc<Mutex<HashMap<Option<String>, AgentTotals>>>,
22}
23
/// Running usage totals for one agent alias (or for unattributed usage).
///
/// `Debug` and `PartialEq` are derived so the aggregates can be logged and
/// compared in tests; all fields are plain numbers, so `Copy` stays cheap.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct AgentTotals {
    /// Sum of `cost_usd` over every recorded event.
    cost_usd: f64,
    /// Sum of `total_tokens` over every recorded event.
    total_tokens: u64,
    /// Number of recorded events.
    request_count: u64,
}
30
31impl CostTracker {
32    /// Create a new cost tracker.
33    pub fn new(config: CostConfig, workspace_dir: &Path) -> Result<Self> {
34        let storage_path = resolve_storage_path(workspace_dir)?;
35
36        let storage = CostStorage::new(&storage_path).with_context(|| {
37            format!(
38                "Failed to open cost storage at {}",
39                storage_path.display().to_string()
40            )
41        })?;
42
43        Ok(Self {
44            config,
45            storage: Arc::new(Mutex::new(storage)),
46            session_id: uuid::Uuid::new_v4().to_string(),
47            session_totals: Arc::new(Mutex::new(HashMap::new())),
48        })
49    }
50
51    /// Get the session ID.
52    pub fn session_id(&self) -> &str {
53        &self.session_id
54    }
55
56    fn lock_storage(&self) -> MutexGuard<'_, CostStorage> {
57        self.storage.lock()
58    }
59
60    fn lock_session_totals(&self) -> MutexGuard<'_, HashMap<Option<String>, AgentTotals>> {
61        self.session_totals.lock()
62    }
63
64    /// Check if a request is within budget.
65    pub fn check_budget(&self, estimated_cost_usd: f64) -> Result<BudgetCheck> {
66        if !self.config.enabled {
67            return Ok(BudgetCheck::Allowed);
68        }
69
70        if !estimated_cost_usd.is_finite() || estimated_cost_usd < 0.0 {
71            ::zeroclaw_log::record!(
72                WARN,
73                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
74                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
75                    .with_attrs(::serde_json::json!({"estimated_cost_usd": estimated_cost_usd})),
76                "cost budget check rejected: estimated cost is not finite or is negative"
77            );
78            anyhow::bail!("Estimated cost must be a finite, non-negative value");
79        }
80
81        let mut storage = self.lock_storage();
82        let (daily_cost, monthly_cost) = storage.get_aggregated_costs()?;
83
84        // Check daily limit
85        let projected_daily = daily_cost + estimated_cost_usd;
86        if projected_daily > self.config.daily_limit_usd {
87            return Ok(BudgetCheck::Exceeded {
88                current_usd: daily_cost,
89                limit_usd: self.config.daily_limit_usd,
90                period: UsagePeriod::Day,
91            });
92        }
93
94        // Check monthly limit
95        let projected_monthly = monthly_cost + estimated_cost_usd;
96        if projected_monthly > self.config.monthly_limit_usd {
97            return Ok(BudgetCheck::Exceeded {
98                current_usd: monthly_cost,
99                limit_usd: self.config.monthly_limit_usd,
100                period: UsagePeriod::Month,
101            });
102        }
103
104        // Check warning thresholds
105        let warn_threshold = f64::from(self.config.warn_at_percent.min(100)) / 100.0;
106        let daily_warn_threshold = self.config.daily_limit_usd * warn_threshold;
107        let monthly_warn_threshold = self.config.monthly_limit_usd * warn_threshold;
108
109        if projected_daily >= daily_warn_threshold {
110            return Ok(BudgetCheck::Warning {
111                current_usd: daily_cost,
112                limit_usd: self.config.daily_limit_usd,
113                period: UsagePeriod::Day,
114            });
115        }
116
117        if projected_monthly >= monthly_warn_threshold {
118            return Ok(BudgetCheck::Warning {
119                current_usd: monthly_cost,
120                limit_usd: self.config.monthly_limit_usd,
121                period: UsagePeriod::Month,
122            });
123        }
124
125        Ok(BudgetCheck::Allowed)
126    }
127
128    /// Record a usage event without per-agent attribution.
129    pub fn record_usage(&self, usage: TokenUsage) -> Result<()> {
130        self.record_usage_with_agent(usage, None)
131    }
132
133    /// Record a usage event attributed to a specific agent alias. When
134    /// `[cost].track_per_agent` is false the alias is dropped before
135    /// persistence.
136    pub fn record_usage_with_agent(
137        &self,
138        usage: TokenUsage,
139        agent_alias: Option<&str>,
140    ) -> Result<()> {
141        if !self.config.enabled {
142            return Ok(());
143        }
144
145        if !usage.cost_usd.is_finite() || usage.cost_usd < 0.0 {
146            ::zeroclaw_log::record!(
147                WARN,
148                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
149                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
150                    .with_attrs(::serde_json::json!({"cost_usd": usage.cost_usd})),
151                "token usage record rejected: cost is not finite or is negative"
152            );
153            anyhow::bail!("Token usage cost must be a finite, non-negative value");
154        }
155
156        let effective_alias = if self.config.track_per_agent {
157            agent_alias.map(str::to_string)
158        } else {
159            None
160        };
161        let cost_usd = usage.cost_usd;
162        let total_tokens = usage.total_tokens;
163        let record = CostRecord::with_agent(&self.session_id, effective_alias.clone(), usage);
164
165        {
166            let mut storage = self.lock_storage();
167            storage.add_record(record)?;
168        }
169
170        {
171            let mut totals = self.lock_session_totals();
172            let entry = totals.entry(effective_alias).or_default();
173            entry.cost_usd += cost_usd;
174            entry.total_tokens += total_tokens;
175            entry.request_count += 1;
176        }
177
178        Ok(())
179    }
180
181    /// Get the current cost summary. When `[cost].track_per_agent` is
182    /// enabled, the response includes a `by_agent` rollup over today's
183    /// records.
184    pub fn get_summary(&self) -> Result<CostSummary> {
185        self.get_summary_filtered(None)
186    }
187
188    /// Filter persisted records by `[from, to)` (either side `None` is
189    /// unbounded) and roll up by_model / by_agent / window totals.
190    /// Bounds come from the caller (the dashboard computes them in the
191    /// operator's local timezone); the tracker doesn't decide what
192    /// "today" means.
193    pub fn get_summary_in_bounds(
194        &self,
195        from: Option<DateTime<Utc>>,
196        to: Option<DateTime<Utc>>,
197    ) -> Result<CostSummary> {
198        let (daily_cost, monthly_cost, records) = {
199            let mut storage = self.lock_storage();
200            let (d, m) = storage.get_aggregated_costs()?;
201            let recs = storage.records_in_bounds(from, to)?;
202            (d, m, recs)
203        };
204        let total_cost: f64 = records.iter().map(|r| r.usage.cost_usd).sum();
205        let total_tokens: u64 = records.iter().map(|r| r.usage.total_tokens).sum();
206        let request_count = records.len();
207        let by_model = build_model_stats(records.iter());
208        let by_agent = if self.config.track_per_agent {
209            build_agent_stats(&records)
210        } else {
211            HashMap::new()
212        };
213        Ok(CostSummary {
214            session_cost_usd: total_cost,
215            daily_cost_usd: daily_cost,
216            monthly_cost_usd: monthly_cost,
217            total_tokens,
218            request_count,
219            by_model,
220            by_agent,
221        })
222    }
223
224    /// Get the current cost summary scoped to a single agent alias. The
225    /// session/day/month figures and `by_model` are filtered to records
226    /// attributed to that alias; `by_agent` is left empty since the
227    /// caller already chose the dimension.
228    pub fn get_summary_for_agent(&self, agent_alias: &str) -> Result<CostSummary> {
229        self.get_summary_filtered(Some(agent_alias))
230    }
231
232    fn get_summary_filtered(&self, agent_filter: Option<&str>) -> Result<CostSummary> {
233        let (daily_cost, monthly_cost, daily_records) = {
234            let mut storage = self.lock_storage();
235            let (d, m) = storage.get_aggregated_costs()?;
236            // Always pull daily_records: per-model and per-agent rollups
237            // both want today's slice. The optional-skip optimisation tied
238            // to `track_per_agent` made the by-model rollup session-scoped,
239            // which surprised operators after a daemon restart and clashes
240            // with the daily totals in the same response.
241            (d, m, storage.daily_records()?)
242        };
243
244        let (session_cost, total_tokens, request_count) = {
245            let totals = self.lock_session_totals();
246            totals
247                .iter()
248                .filter(|(alias, _)| match agent_filter {
249                    Some(want) => alias.as_deref() == Some(want),
250                    None => true,
251                })
252                .fold((0.0_f64, 0_u64, 0_usize), |(c, t, r), (_, v)| {
253                    (
254                        c + v.cost_usd,
255                        t + v.total_tokens,
256                        r + v.request_count as usize,
257                    )
258                })
259        };
260
261        let matches_agent = |record: &CostRecord| match agent_filter {
262            Some(alias) => record.agent_alias.as_deref() == Some(alias),
263            None => true,
264        };
265
266        // Daily-scoped per-model rollup. Filter by agent when scoped.
267        let model_records: Vec<&CostRecord> =
268            daily_records.iter().filter(|r| matches_agent(r)).collect();
269        let by_model = build_model_stats(model_records.iter().copied());
270
271        let (daily_total, monthly_total, by_agent) = if let Some(alias) = agent_filter {
272            // Per-agent view: re-aggregate day/month from persisted records.
273            let mut daily_total = 0.0;
274            let mut monthly_total = 0.0;
275            let today = Utc::now().date_naive();
276            let now = Utc::now();
277            for record in &daily_records {
278                if record.agent_alias.as_deref() != Some(alias) {
279                    continue;
280                }
281                let ts = record.usage.timestamp.naive_utc();
282                if ts.date() == today {
283                    daily_total += record.usage.cost_usd;
284                }
285                if ts.year() == now.year() && ts.month() == now.month() {
286                    monthly_total += record.usage.cost_usd;
287                }
288            }
289            (daily_total, monthly_total, HashMap::new())
290        } else if self.config.track_per_agent {
291            let by_agent = build_agent_stats(&daily_records);
292            (daily_cost, monthly_cost, by_agent)
293        } else {
294            (daily_cost, monthly_cost, HashMap::new())
295        };
296
297        Ok(CostSummary {
298            session_cost_usd: session_cost,
299            daily_cost_usd: daily_total,
300            monthly_cost_usd: monthly_total,
301            total_tokens,
302            request_count,
303            by_model,
304            by_agent,
305        })
306    }
307
308    /// Get the daily cost for a specific date.
309    pub fn get_daily_cost(&self, date: NaiveDate) -> Result<f64> {
310        let storage = self.lock_storage();
311        storage.get_cost_for_date(date)
312    }
313
314    /// Get the monthly cost for a specific month.
315    pub fn get_monthly_cost(&self, year: i32, month: u32) -> Result<f64> {
316        let storage = self.lock_storage();
317        storage.get_cost_for_month(year, month)
318    }
319}
320
321// ── Process-global singleton ────────────────────────────────────────
322// Both the gateway and the channels supervisor share a single CostTracker
323// so that budget enforcement is consistent across all paths.
324
325static GLOBAL_COST_TRACKER: OnceLock<Option<Arc<CostTracker>>> = OnceLock::new();
326
327impl CostTracker {
328    /// Return the process-global `CostTracker`, creating it on first call.
329    /// Subsequent calls (from gateway or channels, whichever starts second)
330    /// receive the same `Arc`.  Returns `None` when cost tracking is disabled
331    /// or initialisation fails.
332    pub fn get_or_init_global(config: CostConfig, workspace_dir: &Path) -> Option<Arc<Self>> {
333        GLOBAL_COST_TRACKER
334            .get_or_init(|| {
335                if !config.enabled {
336                    return None;
337                }
338                match Self::new(config, workspace_dir) {
339                    Ok(ct) => Some(Arc::new(ct)),
340                    Err(e) => {
341                        ::zeroclaw_log::record!(
342                            WARN,
343                            ::zeroclaw_log::Event::new(
344                                module_path!(),
345                                ::zeroclaw_log::Action::Note
346                            )
347                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
348                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
349                            "Failed to initialize global cost tracker"
350                        );
351                        None
352                    }
353                }
354            })
355            .clone()
356    }
357}
358
359fn resolve_storage_path(workspace_dir: &Path) -> Result<PathBuf> {
360    let storage_path = workspace_dir.join("state").join("costs.jsonl");
361    let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db");
362
363    if !storage_path.exists() && legacy_path.exists() {
364        if let Some(parent) = storage_path.parent() {
365            fs::create_dir_all(parent).with_context(|| {
366                format!(
367                    "Failed to create directory {}",
368                    parent.display().to_string()
369                )
370            })?;
371        }
372
373        if let Err(error) = fs::rename(&legacy_path, &storage_path) {
374            ::zeroclaw_log::record!(
375                WARN,
376                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
377                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
378                &format!(
379                    "Failed to move legacy cost storage from {} to {}: {error}; falling back to copy",
380                    legacy_path.display().to_string(),
381                    storage_path.display().to_string()
382                )
383            );
384            fs::copy(&legacy_path, &storage_path).with_context(|| {
385                format!(
386                    "Failed to copy legacy cost storage from {} to {}",
387                    legacy_path.display().to_string(),
388                    storage_path.display()
389                )
390            })?;
391        }
392    }
393
394    Ok(storage_path)
395}
396
397fn build_model_stats<'a, I>(records: I) -> HashMap<String, ModelStats>
398where
399    I: IntoIterator<Item = &'a CostRecord>,
400{
401    let mut by_model: HashMap<String, ModelStats> = HashMap::new();
402
403    for record in records {
404        let entry = by_model
405            .entry(record.usage.model.clone())
406            .or_insert_with(|| ModelStats {
407                model: record.usage.model.clone(),
408                cost_usd: 0.0,
409                total_tokens: 0,
410                input_tokens: 0,
411                output_tokens: 0,
412                cached_input_tokens: 0,
413                request_count: 0,
414            });
415
416        entry.cost_usd += record.usage.cost_usd;
417        entry.total_tokens += record.usage.total_tokens;
418        entry.input_tokens += record.usage.input_tokens;
419        entry.output_tokens += record.usage.output_tokens;
420        entry.cached_input_tokens += record.usage.cached_input_tokens;
421        entry.request_count += 1;
422    }
423
424    by_model
425}
426
427fn build_agent_stats(records: &[CostRecord]) -> HashMap<String, AgentCostStats> {
428    let mut by_agent: HashMap<String, AgentCostStats> = HashMap::new();
429
430    for record in records {
431        let Some(alias) = record.agent_alias.as_deref() else {
432            continue;
433        };
434        let entry = by_agent
435            .entry(alias.to_string())
436            .or_insert_with(|| AgentCostStats {
437                agent_alias: alias.to_string(),
438                cost_usd: 0.0,
439                total_tokens: 0,
440                input_tokens: 0,
441                output_tokens: 0,
442                cached_input_tokens: 0,
443                request_count: 0,
444            });
445
446        entry.cost_usd += record.usage.cost_usd;
447        entry.total_tokens += record.usage.total_tokens;
448        entry.input_tokens += record.usage.input_tokens;
449        entry.output_tokens += record.usage.output_tokens;
450        entry.cached_input_tokens += record.usage.cached_input_tokens;
451        entry.request_count += 1;
452    }
453
454    by_agent
455}
456
457/// Persistent storage for cost records.
458struct CostStorage {
459    path: PathBuf,
460    daily_cost_usd: f64,
461    monthly_cost_usd: f64,
462    cached_day: NaiveDate,
463    cached_year: i32,
464    cached_month: u32,
465}
466
467impl CostStorage {
468    /// Create or open cost storage.
469    fn new(path: &Path) -> Result<Self> {
470        if let Some(parent) = path.parent() {
471            fs::create_dir_all(parent).with_context(|| {
472                format!(
473                    "Failed to create directory {}",
474                    parent.display().to_string()
475                )
476            })?;
477        }
478
479        let now = Utc::now();
480        let mut storage = Self {
481            path: path.to_path_buf(),
482            daily_cost_usd: 0.0,
483            monthly_cost_usd: 0.0,
484            cached_day: now.date_naive(),
485            cached_year: now.year(),
486            cached_month: now.month(),
487        };
488
489        storage.rebuild_aggregates(
490            storage.cached_day,
491            storage.cached_year,
492            storage.cached_month,
493        )?;
494
495        Ok(storage)
496    }
497
498    fn for_each_record<F>(&self, mut on_record: F) -> Result<()>
499    where
500        F: FnMut(CostRecord),
501    {
502        if !self.path.exists() {
503            return Ok(());
504        }
505
506        let file = File::open(&self.path).with_context(|| {
507            format!(
508                "Failed to read cost storage from {}",
509                self.path.display().to_string()
510            )
511        })?;
512        let reader = BufReader::new(file);
513
514        for (line_number, line) in reader.lines().enumerate() {
515            let raw_line = line.with_context(|| {
516                format!(
517                    "Failed to read line {} from cost storage {}",
518                    line_number + 1,
519                    self.path.display()
520                )
521            })?;
522
523            let trimmed = raw_line.trim();
524            if trimmed.is_empty() {
525                continue;
526            }
527
528            match serde_json::from_str::<CostRecord>(trimmed) {
529                Ok(record) => on_record(record),
530                Err(error) => {
531                    ::zeroclaw_log::record!(
532                        WARN,
533                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
534                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
535                        &format!(
536                            "Skipping malformed cost record at {}:{}: {error}",
537                            self.path.display().to_string(),
538                            line_number + 1
539                        )
540                    );
541                }
542            }
543        }
544
545        Ok(())
546    }
547
548    fn rebuild_aggregates(&mut self, day: NaiveDate, year: i32, month: u32) -> Result<()> {
549        let mut daily_cost = 0.0;
550        let mut monthly_cost = 0.0;
551
552        self.for_each_record(|record| {
553            let timestamp = record.usage.timestamp.naive_utc();
554
555            if timestamp.date() == day {
556                daily_cost += record.usage.cost_usd;
557            }
558
559            if timestamp.year() == year && timestamp.month() == month {
560                monthly_cost += record.usage.cost_usd;
561            }
562        })?;
563
564        self.daily_cost_usd = daily_cost;
565        self.monthly_cost_usd = monthly_cost;
566        self.cached_day = day;
567        self.cached_year = year;
568        self.cached_month = month;
569
570        Ok(())
571    }
572
573    fn ensure_period_cache_current(&mut self) -> Result<()> {
574        let now = Utc::now();
575        let day = now.date_naive();
576        let year = now.year();
577        let month = now.month();
578
579        if day != self.cached_day || year != self.cached_year || month != self.cached_month {
580            self.rebuild_aggregates(day, year, month)?;
581        }
582
583        Ok(())
584    }
585
586    /// Add a new record.
587    fn add_record(&mut self, record: CostRecord) -> Result<()> {
588        let mut file = OpenOptions::new()
589            .create(true)
590            .append(true)
591            .open(&self.path)
592            .with_context(|| {
593                format!(
594                    "Failed to open cost storage at {}",
595                    self.path.display().to_string()
596                )
597            })?;
598
599        writeln!(file, "{}", serde_json::to_string(&record)?).with_context(|| {
600            format!(
601                "Failed to write cost record to {}",
602                self.path.display().to_string()
603            )
604        })?;
605        file.sync_all().with_context(|| {
606            format!(
607                "Failed to sync cost storage at {}",
608                self.path.display().to_string()
609            )
610        })?;
611
612        self.ensure_period_cache_current()?;
613
614        let timestamp = record.usage.timestamp.naive_utc();
615        if timestamp.date() == self.cached_day {
616            self.daily_cost_usd += record.usage.cost_usd;
617        }
618        if timestamp.year() == self.cached_year && timestamp.month() == self.cached_month {
619            self.monthly_cost_usd += record.usage.cost_usd;
620        }
621
622        Ok(())
623    }
624
625    /// Get aggregated costs for current day and month.
626    fn get_aggregated_costs(&mut self) -> Result<(f64, f64)> {
627        self.ensure_period_cache_current()?;
628        Ok((self.daily_cost_usd, self.monthly_cost_usd))
629    }
630
631    /// Snapshot every record whose timestamp falls within the current
632    /// calendar month. Used to build per-agent rollups without folding a
633    /// new aggregate table into the JSONL file.
634    fn daily_records(&mut self) -> Result<Vec<CostRecord>> {
635        self.ensure_period_cache_current()?;
636        let year = self.cached_year;
637        let month = self.cached_month;
638        let mut out = Vec::new();
639        self.for_each_record(|record| {
640            let ts = record.usage.timestamp.naive_utc();
641            if ts.year() == year && ts.month() == month {
642                out.push(record);
643            }
644        })?;
645        Ok(out)
646    }
647
648    fn records_in_bounds(
649        &mut self,
650        from: Option<DateTime<Utc>>,
651        to: Option<DateTime<Utc>>,
652    ) -> Result<Vec<CostRecord>> {
653        let mut out = Vec::new();
654        self.for_each_record(|record| {
655            let ts = record.usage.timestamp;
656            if from.is_some_and(|f| ts < f) {
657                return;
658            }
659            if to.is_some_and(|t| ts >= t) {
660                return;
661            }
662            out.push(record);
663        })?;
664        Ok(out)
665    }
666
667    /// Get cost for a specific date.
668    fn get_cost_for_date(&self, date: NaiveDate) -> Result<f64> {
669        let mut cost = 0.0;
670
671        self.for_each_record(|record| {
672            if record.usage.timestamp.naive_utc().date() == date {
673                cost += record.usage.cost_usd;
674            }
675        })?;
676
677        Ok(cost)
678    }
679
680    /// Get cost for a specific month.
681    fn get_cost_for_month(&self, year: i32, month: u32) -> Result<f64> {
682        let mut cost = 0.0;
683
684        self.for_each_record(|record| {
685            let timestamp = record.usage.timestamp.naive_utc();
686            if timestamp.year() == year && timestamp.month() == month {
687                cost += record.usage.cost_usd;
688            }
689        })?;
690
691        Ok(cost)
692    }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Default cost config with tracking switched on.
    fn enabled_config() -> CostConfig {
        CostConfig {
            enabled: true,
            ..Default::default()
        }
    }

    /// Build a tracker over a throwaway workspace. The `TempDir` is returned so
    /// the caller keeps the directory alive for the whole test.
    fn tracker_with(config: CostConfig) -> (TempDir, CostTracker) {
        let workspace = TempDir::new().unwrap();
        let tracker = CostTracker::new(config, workspace.path()).unwrap();
        (workspace, tracker)
    }

    /// Append `lines` verbatim (each followed by a newline) to the storage file
    /// the tracker will use for `workspace`, creating its directory first.
    fn seed_storage_file(workspace: &Path, lines: &[String]) {
        let path = resolve_storage_path(workspace).unwrap();
        if let Some(dir) = path.parent() {
            fs::create_dir_all(dir).unwrap();
        }
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.sync_all().unwrap();
    }

    #[test]
    fn cost_tracker_initialization() {
        let (_workspace, tracker) = tracker_with(enabled_config());
        assert!(!tracker.session_id().is_empty());
    }

    #[test]
    fn budget_check_when_disabled() {
        let disabled = CostConfig {
            enabled: false,
            ..Default::default()
        };
        let (_workspace, tracker) = tracker_with(disabled);

        let verdict = tracker.check_budget(1000.0).unwrap();
        assert!(matches!(verdict, BudgetCheck::Allowed));
    }

    #[test]
    fn record_usage_and_get_summary() {
        let (_workspace, tracker) = tracker_with(enabled_config());

        tracker
            .record_usage(TokenUsage::new("test/model", 1000, 500, 0, 1.0, 2.0, 0.0))
            .unwrap();

        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 1);
        assert!(summary.session_cost_usd > 0.0);
        assert_eq!(summary.by_model.len(), 1);
    }

    #[test]
    fn budget_exceeded_daily_limit() {
        let tight_budget = CostConfig {
            enabled: true,
            daily_limit_usd: 0.01, // Very low limit
            ..Default::default()
        };
        let (_workspace, tracker) = tracker_with(tight_budget);

        // Roughly 0.02 USD, already over the 0.01 daily limit.
        tracker
            .record_usage(TokenUsage::new("test/model", 10000, 5000, 0, 1.0, 2.0, 0.0))
            .unwrap();

        let verdict = tracker.check_budget(0.01).unwrap();
        assert!(matches!(verdict, BudgetCheck::Exceeded { .. }));
    }

    #[test]
    fn summary_by_model_is_daily_scoped() {
        // by_model is built from persisted records, not only this session's
        // in-memory totals. A record that another session wrote earlier today
        // must therefore still appear after a restart. This test only covers
        // records timestamped today; it does not check that older records are
        // excluded.
        let workspace = TempDir::new().unwrap();

        let prior_today = CostRecord::new(
            "prior-session",
            TokenUsage::new("prior/model", 500, 500, 0, 1.0, 1.0, 0.0),
        );
        seed_storage_file(
            workspace.path(),
            &[serde_json::to_string(&prior_today).unwrap()],
        );

        let tracker = CostTracker::new(enabled_config(), workspace.path()).unwrap();
        tracker
            .record_usage(TokenUsage::new(
                "session/model",
                1000,
                1000,
                0,
                1.0,
                1.0,
                0.0,
            ))
            .unwrap();

        let summary = tracker.get_summary().unwrap();
        assert_eq!(
            summary.by_model.len(),
            2,
            "by_model must include every model that recorded today, \
             regardless of which session wrote the record"
        );
        for model in ["session/model", "prior/model"] {
            assert!(summary.by_model.contains_key(model));
        }
    }

    #[test]
    fn malformed_lines_are_ignored_while_loading() {
        let workspace = TempDir::new().unwrap();

        let valid_usage = TokenUsage::new("test/model", 1000, 0, 0, 1.0, 1.0, 0.0);
        let valid_record = CostRecord::new("session-a", valid_usage.clone());

        // One good record, one garbage line, one blank line.
        seed_storage_file(
            workspace.path(),
            &[
                serde_json::to_string(&valid_record).unwrap(),
                "not-a-json-line".to_string(),
                String::new(),
            ],
        );

        let tracker = CostTracker::new(enabled_config(), workspace.path()).unwrap();
        let today_cost = tracker.get_daily_cost(Utc::now().date_naive()).unwrap();
        assert!((today_cost - valid_usage.cost_usd).abs() < f64::EPSILON);
    }

    #[test]
    fn per_agent_aggregation_buckets_by_alias() {
        let (_workspace, tracker) = tracker_with(enabled_config());

        let events = [
            (1_000, 1_000, "scout"),
            (2_000, 0, "scout"),
            (500, 500, "scribe"),
        ];
        for (input, output, alias) in events {
            tracker
                .record_usage_with_agent(
                    TokenUsage::new("test/model", input, output, 0, 1.0, 1.0, 0.0),
                    Some(alias),
                )
                .unwrap();
        }

        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.by_agent.len(), 2);

        let scout = summary.by_agent.get("scout").unwrap();
        assert_eq!(scout.request_count, 2);
        assert_eq!(scout.total_tokens, 4_000);

        let scribe = summary.by_agent.get("scribe").unwrap();
        assert_eq!(scribe.request_count, 1);
        assert_eq!(scribe.total_tokens, 1_000);

        let scoped = tracker.get_summary_for_agent("scout").unwrap();
        assert_eq!(scoped.request_count, 2);
        assert!(
            scoped.by_agent.is_empty(),
            "per-agent view doesn't re-bucket"
        );
        assert!(
            (scoped.daily_cost_usd - scout.cost_usd).abs() < 1e-9,
            "daily filtered to alias must match by_agent bucket"
        );
    }

    #[test]
    fn track_per_agent_disabled_strips_alias() {
        let unattributed = CostConfig {
            enabled: true,
            track_per_agent: false,
            ..Default::default()
        };
        let (_workspace, tracker) = tracker_with(unattributed);

        tracker
            .record_usage_with_agent(
                TokenUsage::new("test/model", 1_000, 1_000, 0, 1.0, 1.0, 0.0),
                Some("scout"),
            )
            .unwrap();

        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 1);
        assert!(
            summary.by_agent.is_empty(),
            "track_per_agent=false must not surface per-agent rollups"
        );
    }

    #[test]
    fn invalid_budget_estimate_is_rejected() {
        let (_workspace, tracker) = tracker_with(enabled_config());

        let message = tracker.check_budget(f64::NAN).unwrap_err().to_string();
        assert!(message.contains("Estimated cost must be a finite, non-negative value"));
    }
}