1use crate::cost::CostTracker;
2use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
3use parking_lot::Mutex;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, OnceLock};
6
/// Pricing tables for every configured model provider.
///
/// The outer key is the model provider name. The inner map is that provider's rate
/// table, whose keys are `"{model}.input"`, `"{model}.output"`, `"{model}.cached_input"`,
/// or a bare `"{model}"` flat rate (resolved by `resolve_rates`). Values are USD per
/// 1M tokens, as stated in the missing-pricing warning.
pub type ModelProviderPricing = HashMap<String, HashMap<String, f64>>;
16
/// Token and cost totals accumulated over a single turn.
///
/// `record_tool_loop_cost_usage` adds to these (token counts via saturating addition),
/// and `ToolLoopCostTrackingContext::snapshot_turn_usage` copies them out of the shared
/// mutex, which is why the type is `Copy`. `PartialEq` lets snapshots be compared
/// directly; `Eq` cannot be derived because `cost_usd` is an `f64`.
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub struct TurnUsage {
    /// Input tokens summed across recorded calls.
    pub input_tokens: u64,
    /// Output tokens summed across recorded calls.
    pub output_tokens: u64,
    /// Cost in USD summed across recorded calls.
    pub cost_usd: f64,
}
27
28#[derive(Clone)]
31pub struct ToolLoopCostTrackingContext {
32 pub tracker: Arc<CostTracker>,
33 pub model_provider_pricing: Arc<ModelProviderPricing>,
34 pub turn_usage: Arc<Mutex<TurnUsage>>,
35 pub agent_alias: Option<String>,
38}
39
40impl ToolLoopCostTrackingContext {
41 pub fn new(
42 tracker: Arc<CostTracker>,
43 model_provider_pricing: Arc<ModelProviderPricing>,
44 ) -> Self {
45 Self {
46 tracker,
47 model_provider_pricing,
48 turn_usage: Arc::new(Mutex::new(TurnUsage::default())),
49 agent_alias: None,
50 }
51 }
52
53 #[must_use]
56 pub fn with_agent_alias(mut self, agent_alias: impl Into<String>) -> Self {
57 self.agent_alias = Some(agent_alias.into());
58 self
59 }
60
61 pub fn snapshot_turn_usage(&self) -> TurnUsage {
64 *self.turn_usage.lock()
65 }
66}
67
// Task-local slot through which the tool loop exposes its cost-tracking context.
tokio::task_local! {
    /// Cost-tracking context for the current tool loop task, if any.
    ///
    /// Read by `record_tool_loop_cost_usage` and `check_tool_loop_budget`. Both treat
    /// "accessed outside a scope" and "scope holds `None`" the same way: they return
    /// `None` and track nothing.
    pub static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
71
/// Resolves `(input, output, cached_input)` rates for `model` from one provider's
/// pricing table.
///
/// Candidate keys are tried in order: the full model id, then (only if the id contains
/// a `/`) the segment after the last `/`. For a candidate `key`, the entries consulted
/// are `"{key}.input"`, `"{key}.output"`, `"{key}.cached_input"` and a bare `"{key}"`.
/// The bare entry is a flat rate that fills in input/output (never cached input) when
/// the specific entry is missing.
///
/// The first candidate with at least one of those entries wins, and any component it
/// leaves undefined becomes `0.0`. If no candidate matches, all three rates are `0.0`.
fn resolve_rates(pricing: &HashMap<String, f64>, model: &str) -> (f64, f64, f64) {
    let lookup = |key: &str| -> Option<(f64, f64, f64)> {
        let flat = pricing.get(key).copied();
        let input = pricing.get(&format!("{key}.input")).copied();
        let output = pricing.get(&format!("{key}.output")).copied();
        let cached = pricing.get(&format!("{key}.cached_input")).copied();

        let has_any_entry = [flat, input, output, cached].iter().any(Option::is_some);
        has_any_entry.then(|| {
            (
                input.or(flat).unwrap_or(0.0),
                output.or(flat).unwrap_or(0.0),
                cached.unwrap_or(0.0),
            )
        })
    };

    // e.g. "anthropic/claude-sonnet-4-5" also tries "claude-sonnet-4-5".
    let last_segment = model.rsplit_once('/').map(|(_, tail)| tail);

    std::iter::once(model)
        .chain(last_segment)
        .find_map(lookup)
        .unwrap_or((0.0, 0.0, 0.0))
}
117
118pub fn record_tool_loop_cost_usage(
121 model_provider_name: &str,
122 model: &str,
123 usage: &zeroclaw_providers::traits::TokenUsage,
124) -> Option<(u64, f64)> {
125 let input_tokens = usage.input_tokens.unwrap_or(0);
126 let output_tokens = usage.output_tokens.unwrap_or(0);
127 let cached_input_tokens = usage.cached_input_tokens.unwrap_or(0);
128 let total_tokens = input_tokens.saturating_add(output_tokens);
129 if total_tokens == 0 {
130 return None;
131 }
132
133 let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
134 .try_with(Clone::clone)
135 .ok()
136 .flatten()?;
137 let pricing = ctx.model_provider_pricing.get(model_provider_name);
138 let (input_rate, output_rate, cached_rate) = pricing
139 .map(|map| resolve_rates(map, model))
140 .unwrap_or((0.0, 0.0, 0.0));
141
142 let cost_usage = CostTokenUsage::new(
143 model,
144 input_tokens,
145 output_tokens,
146 cached_input_tokens,
147 input_rate,
148 output_rate,
149 cached_rate,
150 );
151
152 if pricing.is_none() || (input_rate == 0.0 && output_rate == 0.0) {
159 warn_once_missing_pricing(model_provider_name, model);
160 }
161
162 if let Err(error) = ctx
163 .tracker
164 .record_usage_with_agent(cost_usage.clone(), ctx.agent_alias.as_deref())
165 {
166 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": model_provider_name, "model": model, "error": format!("{}", error)})), "Failed to record cost tracking usage: ");
167 }
168
169 {
170 let mut usage = ctx.turn_usage.lock();
171 usage.input_tokens = usage.input_tokens.saturating_add(input_tokens);
172 usage.output_tokens = usage.output_tokens.saturating_add(output_tokens);
173 usage.cost_usd += cost_usage.cost_usd;
174 }
175
176 Some((cost_usage.total_tokens, cost_usage.cost_usd))
177}
178
179fn missing_pricing_first_sighting(
184 seen: &Mutex<HashSet<(String, String)>>,
185 model_provider: &str,
186 model: &str,
187) -> bool {
188 seen.lock()
189 .insert((model_provider.to_string(), model.to_string()))
190}
191
192fn warn_once_missing_pricing(model_provider: &str, model: &str) {
202 static SEEN: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
203 let seen = SEEN.get_or_init(|| Mutex::new(HashSet::new()));
204 if missing_pricing_first_sighting(seen, model_provider, model) {
205 ::zeroclaw_log::record!(
206 WARN,
207 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
208 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
209 .with_attrs(
210 ::serde_json::json!({"model_provider": model_provider, "model": model})
211 ),
212 "Cost tracking: no pricing entry found for {model_provider}/{model} — \
213 token usage will be recorded with zero cost and budget enforcement \
214 is inert for this model. Add a `pricing` table to the model provider \
215 entry in config.toml (under `[model_providers.\"{model_provider}\"]`) \
216 with `\"{model}.input\"` and `\"{model}.output\"` keys (USD per 1M tokens). \
217 This warning fires once per (model_provider, model) pair per process."
218 );
219 } else {
220 ::zeroclaw_log::record!(
221 DEBUG,
222 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
223 ::serde_json::json!({"model_provider": model_provider, "model": model})
224 ),
225 "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
226 );
227 }
228}
229
230pub fn check_tool_loop_budget() -> Option<BudgetCheck> {
233 TOOL_LOOP_COST_TRACKING_CONTEXT
234 .try_with(Clone::clone)
235 .ok()
236 .flatten()
237 .map(|ctx| {
238 ctx.tracker
239 .check_budget(0.0)
240 .unwrap_or(BudgetCheck::Allowed)
241 })
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// A dedup set independent of the process-wide one in `warn_once_missing_pricing`.
    fn fresh_seen() -> Mutex<HashSet<(String, String)>> {
        Mutex::new(HashSet::new())
    }

    /// Builds a single provider's rate table from `(key, rate)` pairs.
    fn pricing(entries: &[(&str, f64)]) -> HashMap<String, f64> {
        entries
            .iter()
            .map(|&(key, rate)| (key.to_string(), rate))
            .collect()
    }

    #[test]
    fn first_sighting_returns_true() {
        let seen = fresh_seen();
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "first observation of a (model_provider, model) pair must report first-sighting"
        );
    }

    #[test]
    fn second_sighting_same_pair_returns_false() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            !missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "second sighting of the same pair must NOT re-fire WARN"
        );
    }

    #[test]
    fn different_models_under_same_provider_are_independent() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M3.0"),
            "different model under same model_provider is a distinct pair"
        );
    }

    #[test]
    fn different_providers_for_same_model_are_independent() {
        let seen = fresh_seen();
        // The same model id reached through two providers must warn for each provider.
        assert!(missing_pricing_first_sighting(
            &seen,
            "openrouter",
            "anthropic/claude-sonnet-4-5"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "anthropic", "anthropic/claude-sonnet-4-5"),
            "different model_provider for the same model is a distinct pair"
        );
    }

    #[test]
    fn empty_strings_dedup_independently() {
        let seen = fresh_seen();
        // Empty components must not collide with each other or with non-empty ones.
        assert!(missing_pricing_first_sighting(&seen, "", "model"));
        assert!(missing_pricing_first_sighting(&seen, "model_provider", ""));
        assert!(missing_pricing_first_sighting(&seen, "", ""));
        assert!(!missing_pricing_first_sighting(&seen, "", ""));
    }

    #[test]
    fn resolve_rates_uses_explicit_keys() {
        let table = pricing(&[
            ("gpt-4o.input", 2.5),
            ("gpt-4o.output", 10.0),
            ("gpt-4o.cached_input", 1.25),
        ]);
        assert_eq!(resolve_rates(&table, "gpt-4o"), (2.5, 10.0, 1.25));
    }

    #[test]
    fn resolve_rates_flat_rate_covers_input_and_output_only() {
        let table = pricing(&[("flat-model", 3.0)]);
        assert_eq!(resolve_rates(&table, "flat-model"), (3.0, 3.0, 0.0));
    }

    #[test]
    fn resolve_rates_specific_key_overrides_flat_rate() {
        let table = pricing(&[("flat-model", 3.0), ("flat-model.output", 7.0)]);
        assert_eq!(resolve_rates(&table, "flat-model"), (3.0, 7.0, 0.0));
    }

    #[test]
    fn resolve_rates_falls_back_to_segment_after_last_slash() {
        let table = pricing(&[
            ("claude-sonnet-4-5.input", 3.0),
            ("claude-sonnet-4-5.output", 15.0),
        ]);
        assert_eq!(
            resolve_rates(&table, "anthropic/claude-sonnet-4-5"),
            (3.0, 15.0, 0.0)
        );
    }

    #[test]
    fn resolve_rates_full_key_match_shadows_suffix() {
        // Any entry for the full id wins, even a partial one; the suffix is not consulted.
        let table = pricing(&[("vendor/m.cached_input", 0.5), ("m.input", 9.0)]);
        assert_eq!(resolve_rates(&table, "vendor/m"), (0.0, 0.0, 0.5));
    }

    #[test]
    fn resolve_rates_unknown_model_is_zero() {
        let table = pricing(&[("other.input", 1.0)]);
        assert_eq!(resolve_rates(&table, "m"), (0.0, 0.0, 0.0));
        assert_eq!(resolve_rates(&HashMap::new(), "a/m"), (0.0, 0.0, 0.0));
    }
}