Skip to main content

zeroclaw_runtime/agent/
cost.rs

1use crate::cost::CostTracker;
2use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
3use parking_lot::Mutex;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, OnceLock};
6
// ── Cost tracking via task-local ──
8
/// Per-provider pricing snapshot consumed by the cost tracker.
///
/// Outer key: model provider alias (e.g. `openrouter`, `anthropic`,
/// `azure-openai`). Inner key: user-defined model identifier, either bare
/// (a flat rate applied to both input and output) or suffixed with
/// `.input`, `.output`, or `.cached_input` to price a single dimension.
/// Values are USD per 1M tokens.
///
/// When no key of any kind exists for the full model identifier, lookups
/// retry with the segment after its last `/` (e.g. `claude-sonnet-4-5` for
/// `anthropic/claude-sonnet-4-5`).
pub type ModelProviderPricing = HashMap<String, HashMap<String, f64>>;
16
/// Per-scope token/cost accumulator.
///
/// `record_tool_loop_cost_usage` adds each call's input/output tokens and
/// cost here, in addition to recording it on the shared `CostTracker`. The
/// wrapping code can then read the total for *this* call after the scope
/// exits, without racing concurrent requests that share the same tracker.
///
/// Derives `PartialEq` so snapshots can be compared directly. `cost_usd` is
/// a float, so there is deliberately no `Eq`.
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub struct TurnUsage {
    /// Sum of input tokens recorded in this scope. Cached-input tokens are
    /// not tracked separately here.
    pub input_tokens: u64,
    /// Sum of output tokens recorded in this scope.
    pub output_tokens: u64,
    /// Sum of the per-call USD cost recorded in this scope.
    pub cost_usd: f64,
}
27
28/// Context for cost tracking within the tool call loop.
29/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
30#[derive(Clone)]
31pub struct ToolLoopCostTrackingContext {
32    pub tracker: Arc<CostTracker>,
33    pub model_provider_pricing: Arc<ModelProviderPricing>,
34    pub turn_usage: Arc<Mutex<TurnUsage>>,
35    /// Alias of the agent driving this turn. Stamped onto persisted
36    /// `CostRecord`s so `/api/cost?agent=<alias>` can attribute spend.
37    pub agent_alias: Option<String>,
38}
39
40impl ToolLoopCostTrackingContext {
41    pub fn new(
42        tracker: Arc<CostTracker>,
43        model_provider_pricing: Arc<ModelProviderPricing>,
44    ) -> Self {
45        Self {
46            tracker,
47            model_provider_pricing,
48            turn_usage: Arc::new(Mutex::new(TurnUsage::default())),
49            agent_alias: None,
50        }
51    }
52
53    /// Attach an agent alias to this context so subsequent
54    /// `record_tool_loop_cost_usage` calls stamp records with it.
55    #[must_use]
56    pub fn with_agent_alias(mut self, agent_alias: impl Into<String>) -> Self {
57        self.agent_alias = Some(agent_alias.into());
58        self
59    }
60
61    /// Snapshot the per-scope usage. Wrapping code calls this after the
62    /// scoped future completes to populate observer-event annotations.
63    pub fn snapshot_turn_usage(&self) -> TurnUsage {
64        *self.turn_usage.lock()
65    }
66}
67
68tokio::task_local! {
69    pub static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
70}
71
/// Look up the `(input, output, cached_input)` USD-per-1M-token rates for
/// `model` in one model provider's pricing map.
///
/// Candidate keys are tried in this order:
///
/// 1. The full model identifier. `{model}.input`, `{model}.output` and
///    `{model}.cached_input` price individual dimensions. A bare `{model}`
///    entry fills in input and/or output when their specific key is absent,
///    but never `cached_input`.
/// 2. Only if no key of any kind exists for the full identifier: the
///    segment after the last `/` (`.../suffix`), under the same rules.
///
/// Any dimension left unpriced comes back as `0.0`, and `(0.0, 0.0, 0.0)`
/// is returned when neither candidate matches; the caller logs a one-shot
/// warn in that case. A zero `cached_input` rate means "no discount": the
/// per-token caller bills the cached subset at the standard input rate.
fn resolve_rates(pricing: &HashMap<String, f64>, model: &str) -> (f64, f64, f64) {
    // Rates configured for one candidate key, or `None` when the map has no
    // entry of any kind for it (so the next candidate gets a chance).
    let rates_for = |candidate: &str| {
        let dimension = |suffix: &str| pricing.get(&format!("{candidate}{suffix}")).copied();
        let input = dimension(".input");
        let output = dimension(".output");
        let cached = dimension(".cached_input");
        let flat = pricing.get(candidate).copied();
        let configured = [input, output, cached, flat].iter().any(Option::is_some);
        // A bare `{candidate}` entry backs up input/output, never cached input.
        configured.then(|| (input.or(flat), output.or(flat), cached))
    };

    rates_for(model)
        .or_else(|| {
            model
                .rsplit_once('/')
                .and_then(|(_, last_segment)| rates_for(last_segment))
        })
        .map(|(input, output, cached)| {
            (
                input.unwrap_or(0.0),
                output.unwrap_or(0.0),
                cached.unwrap_or(0.0),
            )
        })
        .unwrap_or((0.0, 0.0, 0.0))
}
117
118/// Record token usage from an LLM response via the task-local cost tracker.
119/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
120pub fn record_tool_loop_cost_usage(
121    model_provider_name: &str,
122    model: &str,
123    usage: &zeroclaw_providers::traits::TokenUsage,
124) -> Option<(u64, f64)> {
125    let input_tokens = usage.input_tokens.unwrap_or(0);
126    let output_tokens = usage.output_tokens.unwrap_or(0);
127    let cached_input_tokens = usage.cached_input_tokens.unwrap_or(0);
128    let total_tokens = input_tokens.saturating_add(output_tokens);
129    if total_tokens == 0 {
130        return None;
131    }
132
133    let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
134        .try_with(Clone::clone)
135        .ok()
136        .flatten()?;
137    let pricing = ctx.model_provider_pricing.get(model_provider_name);
138    let (input_rate, output_rate, cached_rate) = pricing
139        .map(|map| resolve_rates(map, model))
140        .unwrap_or((0.0, 0.0, 0.0));
141
142    let cost_usage = CostTokenUsage::new(
143        model,
144        input_tokens,
145        output_tokens,
146        cached_input_tokens,
147        input_rate,
148        output_rate,
149        cached_rate,
150    );
151
152    // Promote first sighting of (model_provider, model) without pricing to a WARN
153    // so operators notice the silent zero-cost record before they need to
154    // grep DEBUG logs. Subsequent sightings stay at DEBUG so the warn
155    // stream doesn't get spammy. Missing pricing means either the
156    // model_provider has no pricing map at all, or the map exists but
157    // produced zero rates for this model.
158    if pricing.is_none() || (input_rate == 0.0 && output_rate == 0.0) {
159        warn_once_missing_pricing(model_provider_name, model);
160    }
161
162    if let Err(error) = ctx
163        .tracker
164        .record_usage_with_agent(cost_usage.clone(), ctx.agent_alias.as_deref())
165    {
166        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": model_provider_name, "model": model, "error": format!("{}", error)})), "Failed to record cost tracking usage: ");
167    }
168
169    {
170        let mut usage = ctx.turn_usage.lock();
171        usage.input_tokens = usage.input_tokens.saturating_add(input_tokens);
172        usage.output_tokens = usage.output_tokens.saturating_add(output_tokens);
173        usage.cost_usd += cost_usage.cost_usd;
174    }
175
176    Some((cost_usage.total_tokens, cost_usage.cost_usd))
177}
178
179/// Insert `(model_provider, model)` into `seen`. Returns `true` on first sighting,
180/// `false` thereafter. Split out from `warn_once_missing_pricing` so the
181/// dedup contract can be unit-tested with a caller-owned set instead of the
182/// process-static one.
183fn missing_pricing_first_sighting(
184    seen: &Mutex<HashSet<(String, String)>>,
185    model_provider: &str,
186    model: &str,
187) -> bool {
188    seen.lock()
189        .insert((model_provider.to_string(), model.to_string()))
190}
191
192/// First-time WARN, subsequent DEBUG, per `(model_provider, model)` pair.
193///
194/// The default pricing catalog has no entries for most non-OpenAI/Anthropic/
195/// Google models. Operators only realize their cost-tracking surface is
196/// reporting zero when they happen to enable DEBUG logging — a pure-DEBUG
197/// signal is too quiet for "your cost enforcement is silently inert" to
198/// register. Promote the first sighting per-pair to WARN with a config-path
199/// pointer; all subsequent same-pair occurrences stay at DEBUG so the warn
200/// stream doesn't get spammy.
201fn warn_once_missing_pricing(model_provider: &str, model: &str) {
202    static SEEN: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
203    let seen = SEEN.get_or_init(|| Mutex::new(HashSet::new()));
204    if missing_pricing_first_sighting(seen, model_provider, model) {
205        ::zeroclaw_log::record!(
206            WARN,
207            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
208                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
209                .with_attrs(
210                    ::serde_json::json!({"model_provider": model_provider, "model": model})
211                ),
212            "Cost tracking: no pricing entry found for {model_provider}/{model} — \
213             token usage will be recorded with zero cost and budget enforcement \
214             is inert for this model. Add a `pricing` table to the model provider \
215             entry in config.toml (under `[model_providers.\"{model_provider}\"]`) \
216             with `\"{model}.input\"` and `\"{model}.output\"` keys (USD per 1M tokens). \
217             This warning fires once per (model_provider, model) pair per process."
218        );
219    } else {
220        ::zeroclaw_log::record!(
221            DEBUG,
222            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
223                ::serde_json::json!({"model_provider": model_provider, "model": model})
224            ),
225            "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
226        );
227    }
228}
229
230/// Check budget before an LLM call. Returns `None` when no cost tracking
231/// context is scoped (tests, delegate, CLI without cost config).
232pub fn check_tool_loop_budget() -> Option<BudgetCheck> {
233    TOOL_LOOP_COST_TRACKING_CONTEXT
234        .try_with(Clone::clone)
235        .ok()
236        .flatten()
237        .map(|ctx| {
238            ctx.tracker
239                .check_budget(0.0)
240                .unwrap_or(BudgetCheck::Allowed)
241        })
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    fn fresh_seen() -> Mutex<HashSet<(String, String)>> {
        Mutex::new(HashSet::new())
    }

    /// Build a single-provider pricing map from `(key, rate)` pairs.
    fn price_map(entries: &[(&str, f64)]) -> HashMap<String, f64> {
        entries
            .iter()
            .map(|&(key, rate)| (key.to_string(), rate))
            .collect()
    }

    #[test]
    fn first_sighting_returns_true() {
        let seen = fresh_seen();
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "first observation of a (model_provider, model) pair must report first-sighting"
        );
    }

    #[test]
    fn second_sighting_same_pair_returns_false() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            !missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7"),
            "second sighting of the same pair must NOT re-fire WARN"
        );
    }

    #[test]
    fn different_models_under_same_provider_are_independent() {
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "minimax",
            "MiniMax-M2.7"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M3.0"),
            "different model under same model_provider is a distinct pair"
        );
    }

    #[test]
    fn different_providers_for_same_model_are_independent() {
        // Same model name served by two different model_providers — operator may
        // configure them at different rates, so the warn must fire for each.
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(
            &seen,
            "openrouter",
            "anthropic/claude-sonnet-4-5"
        ));
        assert!(
            missing_pricing_first_sighting(&seen, "anthropic", "anthropic/claude-sonnet-4-5"),
            "different model_provider for the same model is a distinct pair"
        );
    }

    #[test]
    fn empty_strings_dedup_independently() {
        // Defensive: empty model_provider or model shouldn't collide with each other.
        let seen = fresh_seen();
        assert!(missing_pricing_first_sighting(&seen, "", "model"));
        assert!(missing_pricing_first_sighting(&seen, "model_provider", ""));
        assert!(missing_pricing_first_sighting(&seen, "", ""));
        assert!(!missing_pricing_first_sighting(&seen, "", ""));
    }

    #[test]
    fn resolve_rates_prefers_dimension_keys_over_flat() {
        // `.input` wins over the flat key; `.output` is absent so the flat key fills it.
        let map = price_map(&[
            ("gpt-4o", 1.0),
            ("gpt-4o.input", 2.5),
            ("gpt-4o.cached_input", 1.25),
        ]);
        assert_eq!(resolve_rates(&map, "gpt-4o"), (2.5, 1.0, 1.25));
    }

    #[test]
    fn resolve_rates_flat_key_never_prices_cached_input() {
        let map = price_map(&[("m", 2.0)]);
        assert_eq!(resolve_rates(&map, "m"), (2.0, 2.0, 0.0));
    }

    #[test]
    fn resolve_rates_falls_back_to_last_path_segment() {
        let map = price_map(&[
            ("claude-sonnet-4-5.input", 3.0),
            ("claude-sonnet-4-5.output", 15.0),
        ]);
        assert_eq!(
            resolve_rates(&map, "anthropic/claude-sonnet-4-5"),
            (3.0, 15.0, 0.0)
        );
    }

    #[test]
    fn resolve_rates_full_key_shadows_suffix() {
        let map = price_map(&[("openai/gpt-4o", 5.0), ("gpt-4o", 1.0)]);
        assert_eq!(resolve_rates(&map, "openai/gpt-4o"), (5.0, 5.0, 0.0));
    }

    #[test]
    fn resolve_rates_any_full_key_match_suppresses_suffix_lookup() {
        // A lone `.cached_input` on the full key still counts as a match, so
        // the suffix entry is never consulted.
        let map = price_map(&[("a/m.cached_input", 0.5), ("m", 9.0)]);
        assert_eq!(resolve_rates(&map, "a/m"), (0.0, 0.0, 0.5));
    }

    #[test]
    fn resolve_rates_only_retries_last_segment() {
        let last = price_map(&[("claude", 2.0)]);
        assert_eq!(resolve_rates(&last, "or/anthropic/claude"), (2.0, 2.0, 0.0));
        let middle = price_map(&[("anthropic/claude", 2.0)]);
        assert_eq!(
            resolve_rates(&middle, "or/anthropic/claude"),
            (0.0, 0.0, 0.0)
        );
    }

    #[test]
    fn resolve_rates_unknown_model_is_zero() {
        let map = price_map(&[("other.input", 1.0)]);
        assert_eq!(resolve_rates(&map, "vendor/unknown"), (0.0, 0.0, 0.0));
        assert_eq!(resolve_rates(&HashMap::new(), "x"), (0.0, 0.0, 0.0));
    }
}