Skip to main content

zeroclaw_config/cost/
types.rs

1use serde::{Deserialize, Serialize};
2
3/// Token usage information from a single API call.
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct TokenUsage {
6    /// Model identifier (e.g., "anthropic/claude-sonnet-4-20250514")
7    pub model: String,
8    /// Input/prompt tokens
9    pub input_tokens: u64,
10    /// Output/completion tokens
11    pub output_tokens: u64,
12    /// Cached input tokens (Anthropic `cache_read_input_tokens`, OpenAI
13    /// `prompt_tokens_details.cached_tokens`). Subset of `input_tokens`
14    /// when reported by the provider; the rate sheet's
15    /// `cached_input_per_mtok` applies to these.
16    #[serde(default, skip_serializing_if = "is_zero_u64")]
17    pub cached_input_tokens: u64,
18    /// Total tokens (input + output, ignoring the cached subset).
19    pub total_tokens: u64,
20    /// Calculated cost in USD
21    pub cost_usd: f64,
22    /// Timestamp of the request
23    pub timestamp: chrono::DateTime<chrono::Utc>,
24}
25
/// Serde `skip_serializing_if` predicate: `true` exactly when the counter is
/// zero, so zero-valued counters are dropped from serialized output.
fn is_zero_u64(v: &u64) -> bool {
    matches!(*v, 0)
}
29
30impl TokenUsage {
31    fn sanitize_price(value: f64) -> f64 {
32        if value.is_finite() && value > 0.0 {
33            value
34        } else {
35            0.0
36        }
37    }
38
39    /// Create a new token usage record. Cached input tokens are billed at
40    /// `cached_input_price_per_million`; the rest of `input_tokens` at the
41    /// standard `input_price_per_million`. When `cached_input_price` is 0
42    /// the cached subset bills at the standard rate (no discount), so
43    /// providers that don't surface a cached rate still produce a sane
44    /// total.
45    pub fn new(
46        model: impl Into<String>,
47        input_tokens: u64,
48        output_tokens: u64,
49        cached_input_tokens: u64,
50        input_price_per_million: f64,
51        output_price_per_million: f64,
52        cached_input_price_per_million: f64,
53    ) -> Self {
54        let model = model.into();
55        let input_price_per_million = Self::sanitize_price(input_price_per_million);
56        let output_price_per_million = Self::sanitize_price(output_price_per_million);
57        let cached_input_price_per_million = Self::sanitize_price(cached_input_price_per_million);
58        let cached_input_tokens = cached_input_tokens.min(input_tokens);
59        let billable_uncached_input = input_tokens.saturating_sub(cached_input_tokens);
60        let total_tokens = input_tokens.saturating_add(output_tokens);
61
62        // Calculate cost: (tokens / 1M) * price_per_million for each band.
63        // Cached subset uses its own rate when set, else falls back to the
64        // standard input rate so providers without a cache-rate aren't
65        // charged $0 for the cached portion.
66        let cached_rate = if cached_input_price_per_million > 0.0 {
67            cached_input_price_per_million
68        } else {
69            input_price_per_million
70        };
71        let input_cost = (billable_uncached_input as f64 / 1_000_000.0) * input_price_per_million;
72        let cached_cost = (cached_input_tokens as f64 / 1_000_000.0) * cached_rate;
73        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million;
74        let cost_usd = input_cost + cached_cost + output_cost;
75
76        Self {
77            model,
78            input_tokens,
79            output_tokens,
80            cached_input_tokens,
81            total_tokens,
82            cost_usd,
83            timestamp: chrono::Utc::now(),
84        }
85    }
86
87    /// Get the total cost.
88    pub fn cost(&self) -> f64 {
89        self.cost_usd
90    }
91}
92
93/// Time period for cost aggregation.
94#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
95pub enum UsagePeriod {
96    Session,
97    Day,
98    Month,
99}
100
101/// A single cost record for persistent storage.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CostRecord {
104    /// Unique identifier
105    pub id: String,
106    /// Token usage details
107    pub usage: TokenUsage,
108    /// Session identifier (for grouping)
109    pub session_id: String,
110    /// Alias of the agent that incurred this cost (HashMap key in
111    /// `config.agents`). `None` for records persisted before per-agent
112    /// attribution, or when `[cost].track_per_agent = false`.
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub agent_alias: Option<String>,
115}
116
117impl CostRecord {
118    /// Create a new cost record without agent attribution.
119    pub fn new(session_id: impl Into<String>, usage: TokenUsage) -> Self {
120        Self {
121            id: uuid::Uuid::new_v4().to_string(),
122            usage,
123            session_id: session_id.into(),
124            agent_alias: None,
125        }
126    }
127
128    /// Create a new cost record attributed to an agent.
129    pub fn with_agent(
130        session_id: impl Into<String>,
131        agent_alias: Option<String>,
132        usage: TokenUsage,
133    ) -> Self {
134        Self {
135            id: uuid::Uuid::new_v4().to_string(),
136            usage,
137            session_id: session_id.into(),
138            agent_alias,
139        }
140    }
141}
142
143/// Budget enforcement result.
144#[derive(Debug, Clone)]
145pub enum BudgetCheck {
146    /// Within budget, request can proceed
147    Allowed,
148    /// Warning threshold exceeded but request can proceed
149    Warning {
150        current_usd: f64,
151        limit_usd: f64,
152        period: UsagePeriod,
153    },
154    /// Budget exceeded, request blocked
155    Exceeded {
156        current_usd: f64,
157        limit_usd: f64,
158        period: UsagePeriod,
159    },
160}
161
162/// Cost summary for reporting.
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct CostSummary {
165    /// Total cost for the session
166    pub session_cost_usd: f64,
167    /// Total cost for the day
168    pub daily_cost_usd: f64,
169    /// Total cost for the month
170    pub monthly_cost_usd: f64,
171    /// Total tokens used
172    pub total_tokens: u64,
173    /// Number of requests
174    pub request_count: usize,
175    /// Breakdown by model
176    pub by_model: std::collections::HashMap<String, ModelStats>,
177    /// Breakdown by agent alias. Empty when `[cost].track_per_agent =
178    /// false` or when no records carry an agent_alias.
179    #[serde(default)]
180    pub by_agent: std::collections::HashMap<String, AgentCostStats>,
181}
182
183/// Statistics for a specific agent alias.
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct AgentCostStats {
186    /// Agent alias (HashMap key in `config.agents`).
187    pub agent_alias: String,
188    /// Total cost attributed to this agent for the period.
189    pub cost_usd: f64,
190    /// Total tokens attributed to this agent for the period (input + output).
191    pub total_tokens: u64,
192    /// Input tokens (uncached + cached). Matches each record's
193    /// `input_tokens` field.
194    #[serde(default)]
195    pub input_tokens: u64,
196    /// Output tokens.
197    #[serde(default)]
198    pub output_tokens: u64,
199    /// Cached input tokens (subset of `input_tokens` served from the
200    /// provider's prompt cache).
201    #[serde(default)]
202    pub cached_input_tokens: u64,
203    /// Number of LLM responses attributed to this agent for the period.
204    pub request_count: usize,
205}
206
207/// Statistics for a specific model.
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ModelStats {
210    /// Model name (upstream resource id from usage telemetry).
211    pub model: String,
212    /// Total cost for this model.
213    pub cost_usd: f64,
214    /// Total tokens for this model (input + output).
215    pub total_tokens: u64,
216    /// Input tokens (uncached + cached).
217    #[serde(default)]
218    pub input_tokens: u64,
219    /// Output tokens.
220    #[serde(default)]
221    pub output_tokens: u64,
222    /// Cached input tokens served from the prompt cache.
223    #[serde(default)]
224    pub cached_input_tokens: u64,
225    /// Number of LLM responses for this model.
226    pub request_count: usize,
227}
228
229impl Default for CostSummary {
230    fn default() -> Self {
231        Self {
232            session_cost_usd: 0.0,
233            daily_cost_usd: 0.0,
234            monthly_cost_usd: 0.0,
235            total_tokens: 0,
236            request_count: 0,
237            by_model: std::collections::HashMap::new(),
238            by_agent: std::collections::HashMap::new(),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_calculation() {
        let usage = TokenUsage::new("test/model", 1000, 500, 0, 3.0, 15.0, 0.0);

        // Expected: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
        assert!((usage.cost_usd - 0.0105).abs() < 0.0001);
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.total_tokens, 1500);
        assert_eq!(usage.cached_input_tokens, 0);
    }

    #[test]
    fn token_usage_cached_input_billed_at_cached_rate() {
        // 200 cached input @ 0.5/Mtok + 800 uncached input @ 3/Mtok + 500 output @ 15/Mtok
        // = (200/1e6)*0.5 + (800/1e6)*3 + (500/1e6)*15
        // = 0.0001 + 0.0024 + 0.0075 = 0.01
        let usage = TokenUsage::new("test/model", 1000, 500, 200, 3.0, 15.0, 0.5);
        assert!((usage.cost_usd - 0.01).abs() < 1e-6, "{}", usage.cost_usd);
        assert_eq!(usage.cached_input_tokens, 200);
    }

    #[test]
    fn token_usage_zero_cached_rate_falls_back_to_input_rate() {
        // Cached rate 0 means "no discount" — bill cached subset at the
        // standard input rate so providers without a published cache rate
        // still produce a sane total.
        let with_cache = TokenUsage::new("test/model", 1000, 500, 200, 3.0, 15.0, 0.0);
        let without_cache = TokenUsage::new("test/model", 1000, 500, 0, 3.0, 15.0, 0.0);
        assert!((with_cache.cost_usd - without_cache.cost_usd).abs() < 1e-9);
    }

    #[test]
    fn token_usage_cached_tokens_clamped_to_input() {
        // More cached tokens than input tokens is inconsistent provider data:
        // the cached count is clamped to `input_tokens`, so all 100 input
        // tokens bill at the cached rate and none at the uncached rate.
        // (100/1e6)*1.0 = 0.0001
        let usage = TokenUsage::new("test/model", 100, 0, 500, 3.0, 15.0, 1.0);
        assert_eq!(usage.cached_input_tokens, 100);
        assert_eq!(usage.total_tokens, 100);
        assert!((usage.cost_usd - 0.0001).abs() < 1e-12, "{}", usage.cost_usd);
    }

    #[test]
    fn token_usage_total_tokens_saturates() {
        let usage = TokenUsage::new("test/model", u64::MAX, 1, 0, 0.0, 0.0, 0.0);
        assert_eq!(usage.total_tokens, u64::MAX);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
    }

    #[test]
    fn token_usage_zero_tokens() {
        let usage = TokenUsage::new("test/model", 0, 0, 0, 3.0, 15.0, 0.0);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 0);
    }

    #[test]
    fn token_usage_negative_or_non_finite_prices_are_clamped() {
        let usage = TokenUsage::new("test/model", 1000, 1000, 0, -3.0, f64::NAN, f64::INFINITY);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 2000);
    }

    #[test]
    fn is_zero_u64_predicate() {
        assert!(is_zero_u64(&0));
        assert!(!is_zero_u64(&7));
    }

    #[test]
    fn cost_record_creation() {
        let usage = TokenUsage::new("test/model", 100, 50, 0, 1.0, 2.0, 0.0);
        let record = CostRecord::new("session-123", usage);

        assert_eq!(record.session_id, "session-123");
        assert!(!record.id.is_empty());
        assert_eq!(record.usage.model, "test/model");
        assert!(record.agent_alias.is_none());
    }

    #[test]
    fn cost_record_with_agent_keeps_alias() {
        let usage = TokenUsage::new("test/model", 10, 5, 0, 1.0, 2.0, 0.0);
        let record = CostRecord::with_agent("session-1", Some("researcher".to_string()), usage);

        assert_eq!(record.session_id, "session-1");
        assert_eq!(record.agent_alias.as_deref(), Some("researcher"));
        assert!(!record.id.is_empty());
    }

    #[test]
    fn cost_record_ids_are_unique() {
        let a = CostRecord::new("s", TokenUsage::new("m", 1, 1, 0, 1.0, 1.0, 0.0));
        let b = CostRecord::new("s", TokenUsage::new("m", 1, 1, 0, 1.0, 1.0, 0.0));
        assert_ne!(a.id, b.id);
    }

    #[test]
    fn cost_summary_default_is_empty() {
        let summary = CostSummary::default();
        assert!(summary.session_cost_usd.abs() < f64::EPSILON);
        assert!(summary.daily_cost_usd.abs() < f64::EPSILON);
        assert!(summary.monthly_cost_usd.abs() < f64::EPSILON);
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.request_count, 0);
        assert!(summary.by_model.is_empty());
        assert!(summary.by_agent.is_empty());
    }
}