Skip to main content

zeroclaw_runtime/agent/
cost.rs

1use crate::cost::CostTracker;
2use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
3use parking_lot::Mutex;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, OnceLock};
6
7// ── Cost tracking via task-local ──
8
/// Per-provider pricing snapshot consumed by the cost tracker.
///
/// Outer key: model provider reference. Lookups arrive as the composite
/// `<type>.<alias>`; the map itself may be keyed by that composite or by the
/// bare provider type (e.g. `openrouter`, `anthropic`, `azure-openai`), and
/// `provider_pricing` tries both.
///
/// Inner key: user-defined model identifier, optionally suffixed with
/// `.input` / `.output` / `.cached_input` to encode the pricing dimension.
/// A bare (unsuffixed) identifier acts as a flat rate for whichever of
/// input/output has no dimension-specific entry. Values are USD per 1M
/// tokens.
pub type ModelProviderPricing = HashMap<String, HashMap<String, f64>>;
16
/// Token and cost totals for one cost-tracking scope.
///
/// `record_tool_loop_cost_usage` adds every recorded response to this
/// accumulator in addition to the shared `CostTracker`. Because the
/// accumulator belongs to the scope rather than to the tracker, the wrapping
/// code can read the total for *its* call once the scope has finished without
/// racing other requests that share the same tracker.
#[derive(Clone, Copy, Debug, Default)]
pub struct TurnUsage {
    /// Input tokens summed (saturating) across the scope's recorded responses.
    pub input_tokens: u64,
    /// Output tokens summed (saturating) across the scope's recorded responses.
    pub output_tokens: u64,
    /// Accumulated cost in USD across the scope's recorded responses.
    pub cost_usd: f64,
}
27
28/// Context for cost tracking within the tool call loop.
29/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
30#[derive(Clone)]
31pub struct ToolLoopCostTrackingContext {
32    pub tracker: Arc<CostTracker>,
33    pub model_provider_pricing: Arc<ModelProviderPricing>,
34    pub turn_usage: Arc<Mutex<TurnUsage>>,
35    /// Alias of the agent driving this turn. Stamped onto persisted
36    /// `CostRecord`s so `/api/cost?agent=<alias>` can attribute spend.
37    pub agent_alias: Option<String>,
38}
39
40impl ToolLoopCostTrackingContext {
41    pub fn new(
42        tracker: Arc<CostTracker>,
43        model_provider_pricing: Arc<ModelProviderPricing>,
44    ) -> Self {
45        Self {
46            tracker,
47            model_provider_pricing,
48            turn_usage: Arc::new(Mutex::new(TurnUsage::default())),
49            agent_alias: None,
50        }
51    }
52
53    /// Attach an agent alias to this context so subsequent
54    /// `record_tool_loop_cost_usage` calls stamp records with it.
55    #[must_use]
56    pub fn with_agent_alias(mut self, agent_alias: impl Into<String>) -> Self {
57        self.agent_alias = Some(agent_alias.into());
58        self
59    }
60
61    /// Snapshot the per-scope usage. Wrapping code calls this after the
62    /// scoped future completes to populate observer-event annotations.
63    pub fn snapshot_turn_usage(&self) -> TurnUsage {
64        *self.turn_usage.lock()
65    }
66}
67
// Task-local slot carrying the cost-tracking context for the running tool
// loop. The readers in this file go through `try_with` and treat "no scope
// installed" and "scope installed with `None`" the same way: both make
// `record_tool_loop_cost_usage` and `check_tool_loop_budget` return `None`.
tokio::task_local! {
    pub static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
71
/// Work out the `(input, output, cached_input)` per-1M-token rates for
/// `model` from one provider's pricing map.
///
/// Candidate keys are tried in order, and the first candidate with *any*
/// matching entry wins:
///
/// 1. `model` verbatim.
/// 2. The last `/`-separated segment of `model` (`vendor/name` -> `name`).
///
/// For a candidate, `{key}.input` / `{key}.output` / `{key}.cached_input`
/// supply the dimension-specific rates. A bare `{key}` entry is a flat
/// fallback for input and output only; it never fills the cached rate.
///
/// Returns `(0.0, 0.0, 0.0)` when neither candidate matches; the caller logs
/// a one-shot warn in that case. A zero `cached_input` rate means "no
/// discount" — the per-token caller bills the cached subset at the standard
/// input rate.
fn resolve_rates(pricing: &HashMap<String, f64>, model: &str) -> (f64, f64, f64) {
    // Rates for a single candidate key, or `None` if the map has no entry
    // for it in any form.
    let rates_for = |name: &str| -> Option<(f64, f64, f64)> {
        let dimension = |suffix: &str| pricing.get(&format!("{name}.{suffix}")).copied();
        let input = dimension("input");
        let output = dimension("output");
        let cached = dimension("cached_input");
        let flat = pricing.get(name).copied();

        let matched = input.is_some() || output.is_some() || cached.is_some() || flat.is_some();
        matched.then(|| {
            (
                input.or(flat).unwrap_or(0.0),
                output.or(flat).unwrap_or(0.0),
                cached.unwrap_or(0.0),
            )
        })
    };

    rates_for(model)
        .or_else(|| {
            model
                .rsplit_once('/')
                .and_then(|(_, last_segment)| rates_for(last_segment))
        })
        .unwrap_or((0.0, 0.0, 0.0))
}
117
118/// Resolve the per-model pricing map for a provider reference.
119///
120/// `model_provider_name` always arrives as the composite `<type>.<alias>`
121/// (see `agent_provider_composite`), but the outer pricing map may be keyed
122/// either way depending on which builder populated it: the CLI / cron / web
123/// agent loop keys by the composite alias, while the channel orchestrator keys
124/// by the bare provider `<type>` (rates are per provider type, not per alias).
125/// Try the composite verbatim first, then fall back to the bare type prefix so
126/// cost tracking resolves regardless of the builder — and so the type-keyed
127/// `cost.rates` sheet is honored on the alias paths too.
128fn provider_pricing<'a>(
129    map: &'a ModelProviderPricing,
130    model_provider_name: &str,
131) -> Option<&'a HashMap<String, f64>> {
132    map.get(model_provider_name).or_else(|| {
133        model_provider_name
134            .split_once('.')
135            .and_then(|(provider_type, _alias)| map.get(provider_type))
136    })
137}
138
139/// Record token usage from an LLM response via the task-local cost tracker.
140/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
141pub fn record_tool_loop_cost_usage(
142    model_provider_name: &str,
143    model: &str,
144    usage: &zeroclaw_providers::traits::TokenUsage,
145) -> Option<(u64, f64)> {
146    let input_tokens = usage.input_tokens.unwrap_or(0);
147    let output_tokens = usage.output_tokens.unwrap_or(0);
148    let cached_input_tokens = usage.cached_input_tokens.unwrap_or(0);
149    let total_tokens = input_tokens.saturating_add(output_tokens);
150    if total_tokens == 0 {
151        return None;
152    }
153
154    let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
155        .try_with(Clone::clone)
156        .ok()
157        .flatten()?;
158    let pricing = provider_pricing(&ctx.model_provider_pricing, model_provider_name);
159    let (input_rate, output_rate, cached_rate) = pricing
160        .map(|map| resolve_rates(map, model))
161        .unwrap_or((0.0, 0.0, 0.0));
162
163    let cost_usage = CostTokenUsage::new(
164        model,
165        input_tokens,
166        output_tokens,
167        cached_input_tokens,
168        input_rate,
169        output_rate,
170        cached_rate,
171    );
172
173    // Promote first sighting of (model_provider, model) without pricing to a WARN
174    // so operators notice the silent zero-cost record before they need to
175    // grep DEBUG logs. Subsequent sightings stay at DEBUG so the warn
176    // stream doesn't get spammy. Missing pricing means either the
177    // model_provider has no pricing map at all, or the map exists but
178    // produced zero rates for this model.
179    if pricing.is_none() || (input_rate == 0.0 && output_rate == 0.0) {
180        warn_once_missing_pricing(model_provider_name, model);
181    }
182
183    if let Err(error) = ctx
184        .tracker
185        .record_usage_with_agent(cost_usage.clone(), ctx.agent_alias.as_deref())
186    {
187        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": model_provider_name, "model": model, "error": format!("{}", error)})), "Failed to record cost tracking usage: ");
188    }
189
190    {
191        let mut usage = ctx.turn_usage.lock();
192        usage.input_tokens = usage.input_tokens.saturating_add(input_tokens);
193        usage.output_tokens = usage.output_tokens.saturating_add(output_tokens);
194        usage.cost_usd += cost_usage.cost_usd;
195    }
196
197    Some((cost_usage.total_tokens, cost_usage.cost_usd))
198}
199
200/// Insert `(model_provider, model)` into `seen`. Returns `true` on first sighting,
201/// `false` thereafter. Split out from `warn_once_missing_pricing` so the
202/// dedup contract can be unit-tested with a caller-owned set instead of the
203/// process-static one.
204fn missing_pricing_first_sighting(
205    seen: &Mutex<HashSet<(String, String)>>,
206    model_provider: &str,
207    model: &str,
208) -> bool {
209    seen.lock()
210        .insert((model_provider.to_string(), model.to_string()))
211}
212
213/// First-time WARN, subsequent DEBUG, per `(model_provider, model)` pair.
214///
215/// The default pricing catalog has no entries for most non-OpenAI/Anthropic/
216/// Google models. Operators only realize their cost-tracking surface is
217/// reporting zero when they happen to enable DEBUG logging — a pure-DEBUG
218/// signal is too quiet for "your cost enforcement is silently inert" to
219/// register. Promote the first sighting per-pair to WARN with a config-path
220/// pointer; all subsequent same-pair occurrences stay at DEBUG so the warn
221/// stream doesn't get spammy.
222fn warn_once_missing_pricing(model_provider: &str, model: &str) {
223    static SEEN: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
224    let seen = SEEN.get_or_init(|| Mutex::new(HashSet::new()));
225    if missing_pricing_first_sighting(seen, model_provider, model) {
226        ::zeroclaw_log::record!(
227            WARN,
228            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
229                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
230                .with_attrs(
231                    ::serde_json::json!({"model_provider": model_provider, "model": model})
232                ),
233            "Cost tracking: no pricing entry found for {model_provider}/{model} — \
234             token usage will be recorded with zero cost and budget enforcement \
235             is inert for this model. Add a `pricing` table to the model provider \
236             entry in config.toml (under `[providers.models.\"{model_provider}\"]`) \
237             with `\"{model}.input\"` and `\"{model}.output\"` keys (USD per 1M tokens). \
238             This warning fires once per (model_provider, model) pair per process."
239        );
240    } else {
241        ::zeroclaw_log::record!(
242            DEBUG,
243            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
244                ::serde_json::json!({"model_provider": model_provider, "model": model})
245            ),
246            "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
247        );
248    }
249}
250
251/// Check budget before an LLM call. Returns `None` when no cost tracking
252/// context is scoped (tests, delegate, CLI without cost config).
253pub fn check_tool_loop_budget() -> Option<BudgetCheck> {
254    TOOL_LOOP_COST_TRACKING_CONTEXT
255        .try_with(Clone::clone)
256        .ok()
257        .flatten()
258        .map(|ctx| {
259            ctx.tracker
260                .check_budget(0.0)
261                .unwrap_or(BudgetCheck::Allowed)
262        })
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty, caller-owned dedup set, so no test touches the process-static one.
    fn fresh_seen() -> Mutex<HashSet<(String, String)>> {
        Mutex::new(HashSet::new())
    }

    #[test]
    fn first_sighting_returns_true() {
        let seen = fresh_seen();
        let first = missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7");
        assert!(
            first,
            "first observation of a (model_provider, model) pair must report first-sighting"
        );
    }

    #[test]
    fn second_sighting_same_pair_returns_false() {
        let seen = fresh_seen();
        let first = missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7");
        assert!(first);

        let second = missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7");
        assert!(
            !second,
            "second sighting of the same pair must NOT re-fire WARN"
        );
    }

    #[test]
    fn different_models_under_same_provider_are_independent() {
        let seen = fresh_seen();
        let first_model = missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M2.7");
        assert!(first_model);

        let other_model = missing_pricing_first_sighting(&seen, "minimax", "MiniMax-M3.0");
        assert!(
            other_model,
            "different model under same model_provider is a distinct pair"
        );
    }

    #[test]
    fn provider_pricing_resolves_composite_and_bare_type_keys() {
        let model_rates: HashMap<String, f64> = HashMap::from([
            ("glm-5.1.input".to_string(), 1.4),
            ("glm-5.1.output".to_string(), 4.4),
        ]);

        // The CLI / agent-loop builder keys the map by the composite
        // `<type>.<alias>`.
        let composite_keyed: ModelProviderPricing =
            HashMap::from([("glm.default".to_string(), model_rates.clone())]);
        let via_composite = provider_pricing(&composite_keyed, "glm.default");
        assert!(
            via_composite.is_some(),
            "composite-keyed map must resolve via the verbatim composite lookup"
        );

        // The channel orchestrator builder keys by the bare provider `<type>`,
        // while the lookup still arrives as the composite alias, so the
        // fallback has to kick in.
        let type_keyed: ModelProviderPricing =
            HashMap::from([("glm".to_string(), model_rates.clone())]);
        let via_fallback = provider_pricing(&type_keyed, "glm.default");
        assert!(
            via_fallback.is_some(),
            "type-keyed map must resolve the composite alias via the bare-type fallback"
        );

        // A provider that is not in the map must not match by accident.
        let unrelated = provider_pricing(&type_keyed, "openai.default");
        assert!(
            unrelated.is_none(),
            "fallback must not resolve a provider type absent from the map"
        );
    }

    #[test]
    fn different_providers_for_same_model_are_independent() {
        // Two model_providers can serve the same model name at different
        // configured rates, so each provider must get its own warning.
        let seen = fresh_seen();
        let model = "anthropic/claude-sonnet-4-5";

        let via_openrouter = missing_pricing_first_sighting(&seen, "openrouter", model);
        assert!(via_openrouter);

        let via_anthropic = missing_pricing_first_sighting(&seen, "anthropic", model);
        assert!(
            via_anthropic,
            "different model_provider for the same model is a distinct pair"
        );
    }

    #[test]
    fn empty_strings_dedup_independently() {
        // Defensive: an empty model_provider and an empty model are distinct
        // keys and must not collide with each other.
        let seen = fresh_seen();
        let distinct_pairs = [("", "model"), ("model_provider", ""), ("", "")];
        for (model_provider, model) in distinct_pairs {
            assert!(missing_pricing_first_sighting(&seen, model_provider, model));
        }
        assert!(!missing_pricing_first_sighting(&seen, "", ""));
    }
}
365}