1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct TokenUsage {
6 pub model: String,
8 pub input_tokens: u64,
10 pub output_tokens: u64,
12 #[serde(default, skip_serializing_if = "is_zero_u64")]
17 pub cached_input_tokens: u64,
18 pub total_tokens: u64,
20 pub cost_usd: f64,
22 pub timestamp: chrono::DateTime<chrono::Utc>,
24}
25
/// Serde `skip_serializing_if` predicate: `true` when the referenced value is 0.
///
/// Takes `&u64` because serde passes the field by reference.
fn is_zero_u64(v: &u64) -> bool {
    matches!(*v, 0)
}
29
30impl TokenUsage {
31 fn sanitize_price(value: f64) -> f64 {
32 if value.is_finite() && value > 0.0 {
33 value
34 } else {
35 0.0
36 }
37 }
38
39 pub fn new(
46 model: impl Into<String>,
47 input_tokens: u64,
48 output_tokens: u64,
49 cached_input_tokens: u64,
50 input_price_per_million: f64,
51 output_price_per_million: f64,
52 cached_input_price_per_million: f64,
53 ) -> Self {
54 let model = model.into();
55 let input_price_per_million = Self::sanitize_price(input_price_per_million);
56 let output_price_per_million = Self::sanitize_price(output_price_per_million);
57 let cached_input_price_per_million = Self::sanitize_price(cached_input_price_per_million);
58 let cached_input_tokens = cached_input_tokens.min(input_tokens);
59 let billable_uncached_input = input_tokens.saturating_sub(cached_input_tokens);
60 let total_tokens = input_tokens.saturating_add(output_tokens);
61
62 let cached_rate = if cached_input_price_per_million > 0.0 {
67 cached_input_price_per_million
68 } else {
69 input_price_per_million
70 };
71 let input_cost = (billable_uncached_input as f64 / 1_000_000.0) * input_price_per_million;
72 let cached_cost = (cached_input_tokens as f64 / 1_000_000.0) * cached_rate;
73 let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million;
74 let cost_usd = input_cost + cached_cost + output_cost;
75
76 Self {
77 model,
78 input_tokens,
79 output_tokens,
80 cached_input_tokens,
81 total_tokens,
82 cost_usd,
83 timestamp: chrono::Utc::now(),
84 }
85 }
86
87 pub fn cost(&self) -> f64 {
89 self.cost_usd
90 }
91}
92
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
95pub enum UsagePeriod {
96 Session,
97 Day,
98 Month,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CostRecord {
104 pub id: String,
106 pub usage: TokenUsage,
108 pub session_id: String,
110 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub agent_alias: Option<String>,
115}
116
117impl CostRecord {
118 pub fn new(session_id: impl Into<String>, usage: TokenUsage) -> Self {
120 Self {
121 id: uuid::Uuid::new_v4().to_string(),
122 usage,
123 session_id: session_id.into(),
124 agent_alias: None,
125 }
126 }
127
128 pub fn with_agent(
130 session_id: impl Into<String>,
131 agent_alias: Option<String>,
132 usage: TokenUsage,
133 ) -> Self {
134 Self {
135 id: uuid::Uuid::new_v4().to_string(),
136 usage,
137 session_id: session_id.into(),
138 agent_alias,
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub enum BudgetCheck {
146 Allowed,
148 Warning {
150 current_usd: f64,
151 limit_usd: f64,
152 period: UsagePeriod,
153 },
154 Exceeded {
156 current_usd: f64,
157 limit_usd: f64,
158 period: UsagePeriod,
159 },
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct CostSummary {
165 pub session_cost_usd: f64,
167 pub daily_cost_usd: f64,
169 pub monthly_cost_usd: f64,
171 pub total_tokens: u64,
173 pub request_count: usize,
175 pub by_model: std::collections::HashMap<String, ModelStats>,
177 #[serde(default)]
180 pub by_agent: std::collections::HashMap<String, AgentCostStats>,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct AgentCostStats {
186 pub agent_alias: String,
188 pub cost_usd: f64,
190 pub total_tokens: u64,
192 #[serde(default)]
195 pub input_tokens: u64,
196 #[serde(default)]
198 pub output_tokens: u64,
199 #[serde(default)]
202 pub cached_input_tokens: u64,
203 pub request_count: usize,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ModelStats {
210 pub model: String,
212 pub cost_usd: f64,
214 pub total_tokens: u64,
216 #[serde(default)]
218 pub input_tokens: u64,
219 #[serde(default)]
221 pub output_tokens: u64,
222 #[serde(default)]
224 pub cached_input_tokens: u64,
225 pub request_count: usize,
227}
228
229impl Default for CostSummary {
230 fn default() -> Self {
231 Self {
232 session_cost_usd: 0.0,
233 daily_cost_usd: 0.0,
234 monthly_cost_usd: 0.0,
235 total_tokens: 0,
236 request_count: 0,
237 by_model: std::collections::HashMap::new(),
238 by_agent: std::collections::HashMap::new(),
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_calculation() {
        let usage = TokenUsage::new("test/model", 1000, 500, 0, 3.0, 15.0, 0.0);

        // 1000 input @ $3/M + 500 output @ $15/M = 0.003 + 0.0075
        assert!((usage.cost_usd - 0.0105).abs() < 0.0001);
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.total_tokens, 1500);
        assert_eq!(usage.cached_input_tokens, 0);
    }

    #[test]
    fn token_usage_cached_input_billed_at_cached_rate() {
        // 800 uncached @ $3/M + 200 cached @ $0.5/M + 500 output @ $15/M
        //   = 0.0024 + 0.0001 + 0.0075 = 0.01
        let usage = TokenUsage::new("test/model", 1000, 500, 200, 3.0, 15.0, 0.5);
        assert!((usage.cost_usd - 0.01).abs() < 1e-6, "{}", usage.cost_usd);
        assert_eq!(usage.cached_input_tokens, 200);
    }

    #[test]
    fn token_usage_zero_cached_rate_falls_back_to_input_rate() {
        // With no cache price, cached tokens cost the same as uncached input.
        let with_cache = TokenUsage::new("test/model", 1000, 500, 200, 3.0, 15.0, 0.0);
        let without_cache = TokenUsage::new("test/model", 1000, 500, 0, 3.0, 15.0, 0.0);
        assert!((with_cache.cost_usd - without_cache.cost_usd).abs() < 1e-9);
    }

    #[test]
    fn token_usage_cached_tokens_clamped_to_input() {
        // 500 "cached" tokens reported against only 100 input tokens: only the
        // 100 real input tokens may be billed, all at the cached rate.
        let usage = TokenUsage::new("test/model", 100, 0, 500, 1.0, 0.0, 0.5);
        assert_eq!(usage.cached_input_tokens, 100);
        assert_eq!(usage.total_tokens, 100);
        // 100 cached @ $0.5/M = 0.00005
        assert!((usage.cost_usd - 0.000_05).abs() < 1e-12, "{}", usage.cost_usd);
    }

    #[test]
    fn token_usage_zero_tokens() {
        let usage = TokenUsage::new("test/model", 0, 0, 0, 3.0, 15.0, 0.0);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 0);
    }

    #[test]
    fn token_usage_total_tokens_saturates() {
        let usage = TokenUsage::new("test/model", u64::MAX, 1, 0, 0.0, 0.0, 0.0);
        assert_eq!(usage.total_tokens, u64::MAX);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
    }

    #[test]
    fn token_usage_negative_or_non_finite_prices_are_clamped() {
        let usage = TokenUsage::new("test/model", 1000, 1000, 0, -3.0, f64::NAN, f64::INFINITY);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 2000);
    }

    #[test]
    fn token_usage_cost_accessor_matches_field() {
        let usage = TokenUsage::new("test/model", 1234, 567, 89, 3.0, 15.0, 0.3);
        assert_eq!(usage.cost().to_bits(), usage.cost_usd.to_bits());
    }

    #[test]
    fn cost_record_creation() {
        let usage = TokenUsage::new("test/model", 100, 50, 0, 1.0, 2.0, 0.0);
        let record = CostRecord::new("session-123", usage);

        assert_eq!(record.session_id, "session-123");
        assert!(!record.id.is_empty());
        assert_eq!(record.usage.model, "test/model");
        assert!(record.agent_alias.is_none());
    }

    #[test]
    fn cost_record_with_agent_sets_alias() {
        let usage = TokenUsage::new("test/model", 100, 50, 0, 1.0, 2.0, 0.0);
        let record = CostRecord::with_agent("session-9", Some("worker".to_string()), usage);

        assert_eq!(record.session_id, "session-9");
        assert_eq!(record.agent_alias.as_deref(), Some("worker"));
        assert!(!record.id.is_empty());
    }

    #[test]
    fn cost_record_ids_are_unique() {
        let a = CostRecord::new("s", TokenUsage::new("m", 1, 1, 0, 1.0, 1.0, 0.0));
        let b = CostRecord::new("s", TokenUsage::new("m", 1, 1, 0, 1.0, 1.0, 0.0));
        assert_ne!(a.id, b.id);
    }

    #[test]
    fn cost_summary_default_is_empty() {
        let summary = CostSummary::default();
        assert!(summary.session_cost_usd.abs() < f64::EPSILON);
        assert!(summary.daily_cost_usd.abs() < f64::EPSILON);
        assert!(summary.monthly_cost_usd.abs() < f64::EPSILON);
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.request_count, 0);
        assert!(summary.by_model.is_empty());
        assert!(summary.by_agent.is_empty());
    }
}