1use crate::cost::CostTracker;
2use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
3use parking_lot::Mutex;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, OnceLock};
6
/// Pricing tables for every configured model provider.
///
/// Outer key: the model-provider name, either a composite `"<type>.<alias>"` name or a bare
/// provider type (both are tried by `provider_pricing`). Inner map: rate keys for that
/// provider's models mapped to their rate. The keys are `"<model>.input"`, `"<model>.output"`,
/// `"<model>.cached_input"`, or a bare `"<model>"` flat rate, as read by `resolve_rates`. The
/// missing-pricing warning describes these rates as USD per 1M tokens.
pub type ModelProviderPricing =
    std::collections::HashMap<String, std::collections::HashMap<String, f64>>;
16
/// Running token and cost totals for the current turn.
///
/// `record_tool_loop_cost_usage` adds every recorded model call to the context's shared
/// `TurnUsage`, and `ToolLoopCostTrackingContext::snapshot_turn_usage` copies it out.
/// `PartialEq` is derived so that snapshots can be compared directly (for example in tests).
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub struct TurnUsage {
    /// Sum of the input tokens reported by each recorded call (saturating).
    pub input_tokens: u64,
    /// Sum of the output tokens reported by each recorded call (saturating).
    pub output_tokens: u64,
    /// Sum of the computed per-call cost, in USD. Calls without pricing contribute zero cost.
    pub cost_usd: f64,
}
27
/// Per-task state used to price, record, and total the token usage of tool-loop model calls.
///
/// It is made available to a task through [`TOOL_LOOP_COST_TRACKING_CONTEXT`]. Cloning shares
/// the same tracker, pricing table, and turn-usage accumulator, because those fields are `Arc`s.
/// Only the optional agent alias is copied.
#[derive(Clone)]
pub struct ToolLoopCostTrackingContext {
    /// Receives every priced usage record and answers budget checks.
    pub tracker: Arc<CostTracker>,
    /// Rates keyed by model-provider name, then by per-model rate key.
    pub model_provider_pricing: Arc<ModelProviderPricing>,
    /// Totals accumulated by `record_tool_loop_cost_usage` and shared by all clones.
    pub turn_usage: Arc<Mutex<TurnUsage>>,
    /// Agent that recorded usage is attributed to, if any.
    pub agent_alias: Option<String>,
}
39
40impl ToolLoopCostTrackingContext {
41 pub fn new(
42 tracker: Arc<CostTracker>,
43 model_provider_pricing: Arc<ModelProviderPricing>,
44 ) -> Self {
45 Self {
46 tracker,
47 model_provider_pricing,
48 turn_usage: Arc::new(Mutex::new(TurnUsage::default())),
49 agent_alias: None,
50 }
51 }
52
53 #[must_use]
56 pub fn with_agent_alias(mut self, agent_alias: impl Into<String>) -> Self {
57 self.agent_alias = Some(agent_alias.into());
58 self
59 }
60
61 pub fn snapshot_turn_usage(&self) -> TurnUsage {
64 *self.turn_usage.lock()
65 }
66}
67
tokio::task_local! {
    /// Cost-tracking context for the current tokio task.
    ///
    /// When this key is not set for the task, or is set to `None`,
    /// `record_tool_loop_cost_usage` records nothing and `check_tool_loop_budget` returns
    /// `None`.
    pub static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
71
/// Resolves the `(input, output, cached_input)` rates for `model` from one provider's map.
///
/// The lookup uses the following keys, each with the following fallbacks:
///
/// - For a lookup key `k`, the keys `"k.input"`, `"k.output"`, and `"k.cached_input"` are
///   read, along with a bare `"k"` flat rate.
/// - The flat rate fills in a missing input or output rate, but never the cached-input rate.
/// - `k` is first `model` verbatim. If that has no entries of any kind and `model` contains
///   `/`, the segment after the last `/` is tried instead, so `vendor/name` can reuse the
///   entries for `name`.
///
/// Any rate that is still missing is `0.0`. A model with no entries at all yields
/// `(0.0, 0.0, 0.0)`.
fn resolve_rates(pricing: &HashMap<String, f64>, model: &str) -> (f64, f64, f64) {
    const UNPRICED: (f64, f64, f64) = (0.0, 0.0, 0.0);

    // Rates for one lookup key, or `None` when the map has no entry of any kind for it.
    let rates_for = |key: &str| -> Option<(f64, f64, f64)> {
        let component = |suffix: &str| pricing.get(&format!("{key}.{suffix}")).copied();
        let input = component("input");
        let output = component("output");
        let cached = component("cached_input");
        let flat = pricing.get(key).copied();
        let has_entry = input.is_some() || output.is_some() || cached.is_some() || flat.is_some();
        has_entry.then(|| {
            (
                input.or(flat).unwrap_or(0.0),
                output.or(flat).unwrap_or(0.0),
                cached.unwrap_or(0.0),
            )
        })
    };

    rates_for(model)
        .or_else(|| {
            model
                .rsplit_once('/')
                .and_then(|(_, suffix)| rates_for(suffix))
        })
        .unwrap_or(UNPRICED)
}
117
118fn provider_pricing<'a>(
129 map: &'a ModelProviderPricing,
130 model_provider_name: &str,
131) -> Option<&'a HashMap<String, f64>> {
132 map.get(model_provider_name).or_else(|| {
133 model_provider_name
134 .split_once('.')
135 .and_then(|(provider_type, _alias)| map.get(provider_type))
136 })
137}
138
139pub fn record_tool_loop_cost_usage(
142 model_provider_name: &str,
143 model: &str,
144 usage: &zeroclaw_providers::traits::TokenUsage,
145) -> Option<(u64, f64)> {
146 let input_tokens = usage.input_tokens.unwrap_or(0);
147 let output_tokens = usage.output_tokens.unwrap_or(0);
148 let cached_input_tokens = usage.cached_input_tokens.unwrap_or(0);
149 let total_tokens = input_tokens.saturating_add(output_tokens);
150 if total_tokens == 0 {
151 return None;
152 }
153
154 let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
155 .try_with(Clone::clone)
156 .ok()
157 .flatten()?;
158 let pricing = provider_pricing(&ctx.model_provider_pricing, model_provider_name);
159 let (input_rate, output_rate, cached_rate) = pricing
160 .map(|map| resolve_rates(map, model))
161 .unwrap_or((0.0, 0.0, 0.0));
162
163 let cost_usage = CostTokenUsage::new(
164 model,
165 input_tokens,
166 output_tokens,
167 cached_input_tokens,
168 input_rate,
169 output_rate,
170 cached_rate,
171 );
172
173 if pricing.is_none() || (input_rate == 0.0 && output_rate == 0.0) {
180 warn_once_missing_pricing(model_provider_name, model);
181 }
182
183 if let Err(error) = ctx
184 .tracker
185 .record_usage_with_agent(cost_usage.clone(), ctx.agent_alias.as_deref())
186 {
187 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": model_provider_name, "model": model, "error": format!("{}", error)})), "Failed to record cost tracking usage: ");
188 }
189
190 {
191 let mut usage = ctx.turn_usage.lock();
192 usage.input_tokens = usage.input_tokens.saturating_add(input_tokens);
193 usage.output_tokens = usage.output_tokens.saturating_add(output_tokens);
194 usage.cost_usd += cost_usage.cost_usd;
195 }
196
197 Some((cost_usage.total_tokens, cost_usage.cost_usd))
198}
199
200fn missing_pricing_first_sighting(
205 seen: &Mutex<HashSet<(String, String)>>,
206 model_provider: &str,
207 model: &str,
208) -> bool {
209 seen.lock()
210 .insert((model_provider.to_string(), model.to_string()))
211}
212
213fn warn_once_missing_pricing(model_provider: &str, model: &str) {
223 static SEEN: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
224 let seen = SEEN.get_or_init(|| Mutex::new(HashSet::new()));
225 if missing_pricing_first_sighting(seen, model_provider, model) {
226 ::zeroclaw_log::record!(
227 WARN,
228 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
229 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
230 .with_attrs(
231 ::serde_json::json!({"model_provider": model_provider, "model": model})
232 ),
233 "Cost tracking: no pricing entry found for {model_provider}/{model} — \
234 token usage will be recorded with zero cost and budget enforcement \
235 is inert for this model. Add a `pricing` table to the model provider \
236 entry in config.toml (under `[providers.models.\"{model_provider}\"]`) \
237 with `\"{model}.input\"` and `\"{model}.output\"` keys (USD per 1M tokens). \
238 This warning fires once per (model_provider, model) pair per process."
239 );
240 } else {
241 ::zeroclaw_log::record!(
242 DEBUG,
243 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
244 ::serde_json::json!({"model_provider": model_provider, "model": model})
245 ),
246 "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
247 );
248 }
249}
250
251pub fn check_tool_loop_budget() -> Option<BudgetCheck> {
254 TOOL_LOOP_COST_TRACKING_CONTEXT
255 .try_with(Clone::clone)
256 .ok()
257 .flatten()
258 .map(|ctx| {
259 ctx.tracker
260 .check_budget(0.0)
261 .unwrap_or(BudgetCheck::Allowed)
262 })
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh dedup set, isolated from the process-wide one in `warn_once_missing_pricing`.
    fn fresh_seen() -> Mutex<HashSet<(String, String)>> {
        Mutex::new(HashSet::new())
    }

    /// Builds one provider's rate map from `(key, rate)` pairs.
    fn rates(entries: &[(&str, f64)]) -> HashMap<String, f64> {
        entries.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn first_sighting_returns_true() {
        let seen = fresh_seen();
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "first observation of a (model_provider, model) pair must report first-sighting"
        );
    }

    #[test]
    fn second_sighting_same_pair_returns_false() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            !missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "second sighting of the same pair must NOT re-fire WARN"
        );
    }

    #[test]
    fn different_models_under_same_provider_are_independent() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M3.0"),
            "different model under same model_provider is a distinct pair"
        );
    }

    #[test]
    fn provider_pricing_resolves_composite_and_bare_type_keys() {
        let mut model_rates: HashMap<String, f64> = HashMap::new();
        model_rates.insert("glm-5.1.input".to_string(), 1.4);
        model_rates.insert("glm-5.1.output".to_string(), 4.4);

        // Table keyed by the full composite "<type>.<alias>" name.
        let mut composite_keyed: ModelProviderPricing = HashMap::new();
        composite_keyed.insert("glm.default".to_string(), model_rates.clone());
        assert!(
            provider_pricing(&composite_keyed, "glm.default").is_some(),
            "composite-keyed map must resolve via the verbatim composite lookup"
        );

        // Table keyed only by the bare provider type; the composite name must still
        // resolve by falling back to the text before the first '.'.
        let mut type_keyed: ModelProviderPricing = HashMap::new();
        type_keyed.insert("glm".to_string(), model_rates.clone());
        assert!(
            provider_pricing(&type_keyed, "glm.default").is_some(),
            "type-keyed map must resolve the composite alias via the bare-type fallback"
        );

        // The fallback must not invent pricing for an unrelated provider type.
        assert!(
            provider_pricing(&type_keyed, "openai.default").is_none(),
            "fallback must not resolve a provider type absent from the map"
        );
    }

    #[test]
    fn provider_pricing_prefers_composite_key_over_bare_type() {
        let mut pricing: ModelProviderPricing = HashMap::new();
        pricing.insert("glm.default".to_string(), rates(&[("x", 1.0)]));
        pricing.insert("glm".to_string(), rates(&[("x", 2.0)]));
        let resolved = provider_pricing(&pricing, "glm.default")
            .and_then(|m| m.get("x"))
            .copied();
        assert_eq!(resolved, Some(1.0), "exact composite entry must win over bare type");
    }

    #[test]
    fn different_providers_for_same_model_are_independent() {
        let seen = fresh_seen();
        // The same model string routed through two providers is two separate pairs.
        assert!(missing_pricing_first_sighting(
            &seen,
            "openrouter",
            "anthropic/claude-sonnet-4-5"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "anthropic", "anthropic/claude-sonnet-4-5"),
            "different model_provider for the same model is a distinct pair"
        );
    }

    #[test]
    fn empty_strings_dedup_independently() {
        let seen = fresh_seen();
        // An empty component must not collide with a non-empty one.
        assert!(missing_pricing_first_sighting(&seen, "", "model"));
        assert!(missing_pricing_first_sighting(&seen, "model_provider", ""));
        assert!(missing_pricing_first_sighting(&seen, "", ""));
        assert!(!missing_pricing_first_sighting(&seen, "", ""));
    }

    #[test]
    fn resolve_rates_reads_split_keys() {
        let pricing = rates(&[
            ("glm-5.1.input", 1.4),
            ("glm-5.1.output", 4.4),
            ("glm-5.1.cached_input", 0.2),
        ]);
        assert_eq!(resolve_rates(&pricing, "glm-5.1"), (1.4, 4.4, 0.2));
    }

    #[test]
    fn resolve_rates_flat_key_backs_input_and_output_but_not_cached() {
        let pricing = rates(&[("flat-model", 2.0), ("mixed", 3.0), ("mixed.output", 7.0)]);
        assert_eq!(resolve_rates(&pricing, "flat-model"), (2.0, 2.0, 0.0));
        assert_eq!(resolve_rates(&pricing, "mixed"), (3.0, 7.0, 0.0));
    }

    #[test]
    fn resolve_rates_falls_back_to_segment_after_last_slash() {
        let pricing = rates(&[
            ("claude-sonnet-4-5.input", 3.0),
            ("claude-sonnet-4-5.output", 15.0),
        ]);
        assert_eq!(
            resolve_rates(&pricing, "anthropic/claude-sonnet-4-5"),
            (3.0, 15.0, 0.0)
        );
    }

    #[test]
    fn resolve_rates_verbatim_match_wins_over_suffix() {
        let pricing = rates(&[("vendor/m.input", 5.0), ("m.input", 1.0)]);
        assert_eq!(resolve_rates(&pricing, "vendor/m"), (5.0, 0.0, 0.0));
    }

    #[test]
    fn resolve_rates_unknown_model_is_zero() {
        let pricing = rates(&[("known.input", 1.0)]);
        assert_eq!(resolve_rates(&pricing, "unknown"), (0.0, 0.0, 0.0));
        assert_eq!(resolve_rates(&pricing, "vendor/unknown"), (0.0, 0.0, 0.0));
    }
}