From 1c48c6e8f924bf58330a76b2f7f1a980d801c48d Mon Sep 17 00:00:00 2001 From: Sayt-0 Date: Fri, 26 Jun 2026 19:03:43 +0200 Subject: [PATCH] feat(cli): print session summary at end of exec runs Non-interactive runs previously emitted no usage recap; the TUI only had the on-demand /cost dialog. This adds an end-of-run summary covering the interaction (tool calls, success rate), per-model token usage, and total cost. The recap is derived from the existing session via a new Session.Stats() helper (tool-call success/failure counts, per-model token and cost breakdown, cache savings, duration). It is written to stderr so stdout stays machine-consumable, and is skipped in JSON output mode. Refs #1345 --- cmd/root/run.go | 9 +- pkg/cli/runner.go | 14 +++ pkg/cli/summary.go | 182 ++++++++++++++++++++++++++++++++++++ pkg/cli/summary_test.go | 115 +++++++++++++++++++++++ pkg/session/stats.go | 190 ++++++++++++++++++++++++++++++++++++++ pkg/session/stats_test.go | 158 +++++++++++++++++++++++++++++++ 6 files changed, 664 insertions(+), 4 deletions(-) create mode 100644 pkg/cli/summary.go create mode 100644 pkg/cli/summary_test.go create mode 100644 pkg/session/stats.go create mode 100644 pkg/session/stats_test.go diff --git a/cmd/root/run.go b/cmd/root/run.go index e6602ef641..c184fab169 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -287,10 +287,10 @@ func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) (command out := cli.NewPrinter(cmd.OutOrStdout()) - return f.runOrExec(ctx, out, args, useTUI) + return f.runOrExec(ctx, out, cmd.ErrOrStderr(), args, useTUI) } -func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []string, useTUI bool) error { +func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, errOut io.Writer, args []string, useTUI bool) error { slog.DebugContext(ctx, "Starting agent", "agent", f.agentName) // Start profiling if requested @@ -436,7 +436,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s // Non-interactive (--exec) runs never clean up the worktree: there // is no safe moment to prompt, and silently discarding work would // be surprising. The worktree is left in place for later inspection. - return f.handleExecMode(ctx, out, rt, sess, args) + return f.handleExecMode(ctx, out, errOut, rt, sess, args) } listenOpt, err := f.startAttachedServer(ctx, out, rt, sess) @@ -823,7 +823,7 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes return localRt, sess, nil } -func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, rt runtime.Runtime, sess *session.Session, args []string) error { +func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, errOut io.Writer, rt runtime.Runtime, sess *session.Session, args []string) error { if f.sessionReadOnly { return errors.New("--session-read-only cannot be used with --exec: there is nothing to display without a TUI") } @@ -840,6 +840,7 @@ func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, rt HideToolCalls: f.hideToolCalls, OutputJSON: f.outputJSON, AutoApprove: f.autoApprove, + SummaryOut: errOut, }, rt, sess, userMessages) if cliErr, ok := errors.AsType[cli.RuntimeError](err); ok { return RuntimeError{Err: cliErr.Err} diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go index 29cd282052..f4a096ed4d 100644 --- a/pkg/cli/runner.go +++ b/pkg/cli/runner.go @@ -74,6 +74,12 @@ type Config struct { AutoApprove bool HideToolCalls bool OutputJSON bool + + // SummaryOut, when non-nil, receives the end-of-run recap (interaction + // summary, model usage, token/cost totals). It is kept separate from the + // agent's response stream so callers can route it to stderr and keep + // stdout machine-consumable. The recap is skipped in JSON output mode. + SummaryOut io.Writer } // Run executes an agent in non-TUI mode, handling user input and runtime events. @@ -311,6 +317,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess } } + // Print an end-of-run recap (interaction summary, model usage, token/cost + // totals) to the dedicated summary writer (typically stderr) so stdout stays + // machine-consumable. JSON mode emits machine-readable events instead, so it + // is skipped there. + if !cfg.OutputJSON && cfg.SummaryOut != nil { + NewPrinter(cfg.SummaryOut).PrintSessionSummary(sess.Stats()) + } + // Wrap runtime errors to prevent duplicate error messages and usage display if lastErr != nil { return RuntimeError{Err: lastErr} diff --git a/pkg/cli/summary.go b/pkg/cli/summary.go new file mode 100644 index 0000000000..5b46e2f94b --- /dev/null +++ b/pkg/cli/summary.go @@ -0,0 +1,182 @@ +package cli + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/docker/docker-agent/pkg/session" +) + +// PrintSessionSummary writes a compact end-of-run recap: an interaction +// summary (duration, tool calls, success rate), a per-model usage table, and +// token/cost totals with cache savings. It is a no-op when the session +// recorded no measurable activity. Callers skip it in JSON output mode. +func (p *Printer) PrintSessionSummary(stats session.Stats) { + if !stats.HasActivity() { + return + } + + p.Println() + p.Println(bold("─── Session summary ───")) + + var rows [][2]string + if stats.ID != "" { + rows = append(rows, [2]string{"Session", stats.ID}) + } + if stats.Duration > 0 { + rows = append(rows, [2]string{"Duration", formatDuration(stats.Duration)}) + } + if stats.Requests > 0 { + rows = append(rows, [2]string{"Requests", strconv.Itoa(stats.Requests)}) + } + if stats.ToolCalls > 0 { + rows = append(rows, + [2]string{"Tool calls", fmt.Sprintf("%d (%d ok, %d failed)", stats.ToolCalls, stats.ToolSuccesses(), stats.ToolErrors)}, + [2]string{"Success rate", fmt.Sprintf("%.1f%%", stats.SuccessRate())}, + ) + } + p.printLabeledRows(rows) + + if len(stats.Models) > 0 { + p.Println() + p.Println(bold("Model usage")) + p.printModelTable(stats.Models) + } + + p.Println() + totals := [][2]string{ + {"Tokens", fmt.Sprintf("%s in / %s out", groupDigits(stats.TotalInput()), groupDigits(stats.OutputTokens))}, + {"Cost", formatCost(stats.Cost)}, + } + if stats.CachedInput > 0 { + totals = append(totals, [2]string{"Cache", fmt.Sprintf("%s input tokens from cache (%.1f%%)", groupDigits(stats.CachedInput), stats.CacheHitRate())}) + } + p.printLabeledRows(totals) +} + +// printLabeledRows prints "label: value" pairs with the colon-terminated +// labels padded to a common width so the values line up. Padding is applied +// outside the color escape so it stays correct regardless of TTY coloring. +func (p *Printer) printLabeledRows(rows [][2]string) { + labelWidth := 0 + for _, r := range rows { + if l := len(r[0]) + 1; l > labelWidth { + labelWidth = l + } + } + for _, r := range rows { + label := r[0] + ":" + p.Printf("%s%s %s\n", bold(label), strings.Repeat(" ", labelWidth-len(label)), r[1]) + } +} + +// printModelTable renders the per-model usage breakdown as an aligned table. +func (p *Printer) printModelTable(models []session.ModelStats) { + headers := []string{"MODEL", "REQS", "INPUT", "CACHE READ", "OUTPUT"} + rows := make([][]string, 0, len(models)) + for _, m := range models { + rows = append(rows, []string{ + m.Model, + strconv.Itoa(m.Requests), + groupDigits(m.InputTokens), + groupDigits(m.CachedInput), + groupDigits(m.OutputTokens), + }) + } + + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + for _, row := range rows { + for i, c := range row { + if len(c) > widths[i] { + widths[i] = len(c) + } + } + } + + p.Println(bold(formatTableRow(headers, widths))) + for _, row := range rows { + p.Println(formatTableRow(row, widths)) + } +} + +// formatTableRow left-aligns the first column (model name) and right-aligns the +// numeric columns, separating them with a fixed gutter. +func formatTableRow(cols []string, widths []int) string { + var b strings.Builder + for i, c := range cols { + if i > 0 { + b.WriteString(" ") + } + pad := strings.Repeat(" ", widths[i]-len(c)) + if i == 0 { + b.WriteString(c) + b.WriteString(pad) + } else { + b.WriteString(pad) + b.WriteString(c) + } + } + return strings.TrimRight(b.String(), " ") +} + +// groupDigits formats an integer with thousands separators (e.g. 70637 -> +// "70,637"). +func groupDigits(n int64) string { + s := strconv.FormatInt(n, 10) + neg := strings.HasPrefix(s, "-") + if neg { + s = s[1:] + } + for i := len(s) - 3; i > 0; i -= 3 { + s = s[:i] + "," + s[i:] + } + if neg { + return "-" + s + } + return s +} + +// formatCost renders a dollar amount, widening the precision for sub-cent +// values so small costs are not all displayed as "$0.00". +func formatCost(cost float64) string { + if cost < 0 { + return "-" + formatCost(-cost) + } + switch { + case cost < 0.0001: + return "$0.00" + case cost < 0.01: + return fmt.Sprintf("$%.4f", cost) + default: + return fmt.Sprintf("$%.2f", cost) + } +} + +// formatDuration renders a duration in a compact, human-readable form. +func formatDuration(d time.Duration) string { + if d <= 0 { + return "0s" + } + if d < time.Minute { + return fmt.Sprintf("%ds", int(d.Seconds())) + } + if d < time.Hour { + m := int(d.Minutes()) + s := int(d.Seconds()) % 60 + if s == 0 { + return fmt.Sprintf("%dm", m) + } + return fmt.Sprintf("%dm %ds", m, s) + } + h := int(d.Hours()) + m := int(d.Minutes()) % 60 + if m == 0 { + return fmt.Sprintf("%dh", h) + } + return fmt.Sprintf("%dh %dm", h, m) +} diff --git a/pkg/cli/summary_test.go b/pkg/cli/summary_test.go new file mode 100644 index 0000000000..b9026202f5 --- /dev/null +++ b/pkg/cli/summary_test.go @@ -0,0 +1,115 @@ +package cli + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/session" +) + +func renderSummary(stats session.Stats) string { + var buf bytes.Buffer + NewPrinter(&buf).PrintSessionSummary(stats) + return buf.String() +} + +func TestPrintSessionSummary_NoActivityIsNoOp(t *testing.T) { + t.Parallel() + + assert.Empty(t, renderSummary(session.Stats{ID: "abc"})) +} + +func TestPrintSessionSummary_FullRecap(t *testing.T) { + t.Parallel() + + out := renderSummary(session.Stats{ + ID: "abc-123", + Duration: 90 * time.Second, + Requests: 4, + ToolCalls: 12, + ToolErrors: 2, + InputTokens: 1000, + CachedInput: 146112, + OutputTokens: 2059, + Cost: 0.1234, + Models: []session.ModelStats{ + {Model: "gemini-2.5-pro", Requests: 12, InputTokens: 22264, CachedInput: 106844, OutputTokens: 237}, + }, + }) + + // Interaction summary. + assert.Contains(t, out, "Session summary") + assert.Contains(t, out, "abc-123") + assert.Contains(t, out, "1m 30s") + assert.Contains(t, out, "12 (10 ok, 2 failed)") + assert.Contains(t, out, "83.3%") + + // Model usage table. + assert.Contains(t, out, "Model usage") + assert.Contains(t, out, "gemini-2.5-pro") + assert.Contains(t, out, "22,264") + assert.Contains(t, out, "106,844") + + // Totals and cache savings. + assert.Contains(t, out, "147,112 in / 2,059 out") + assert.Contains(t, out, "$0.12") + assert.Contains(t, out, "146,112 input tokens from cache") +} + +func TestPrintSessionSummary_NoToolCallsOmitsToolLines(t *testing.T) { + t.Parallel() + + out := renderSummary(session.Stats{ + ID: "no-tools", + Requests: 1, + InputTokens: 500, + OutputTokens: 100, + Cost: 0.002, + Models: []session.ModelStats{{Model: "m", Requests: 1, InputTokens: 500, OutputTokens: 100}}, + }) + + assert.NotContains(t, out, "Tool calls") + assert.NotContains(t, out, "Success rate") + assert.Contains(t, out, "500 in / 100 out") +} + +func TestGroupDigits(t *testing.T) { + t.Parallel() + + cases := map[int64]string{ + 0: "0", + 5: "5", + 999: "999", + 1000: "1,000", + 70637: "70,637", + 1234567: "1,234,567", + -1500: "-1,500", + } + for in, want := range cases { + assert.Equalf(t, want, groupDigits(in), "groupDigits(%d)", in) + } +} + +func TestFormatCost(t *testing.T) { + t.Parallel() + + assert.Equal(t, "$0.00", formatCost(0)) + assert.Equal(t, "$0.00", formatCost(0.00005)) + assert.Equal(t, "$0.0050", formatCost(0.005)) + assert.Equal(t, "$0.12", formatCost(0.1234)) + assert.Equal(t, "$3.00", formatCost(3)) +} + +func TestFormatDuration(t *testing.T) { + t.Parallel() + + assert.Equal(t, "0s", formatDuration(0)) + assert.Equal(t, "45s", formatDuration(45*time.Second)) + assert.Equal(t, "1m", formatDuration(time.Minute)) + assert.Equal(t, "1m 30s", formatDuration(90*time.Second)) + assert.Equal(t, "1h", formatDuration(time.Hour)) + assert.Equal(t, "1h 1m", formatDuration(time.Hour+time.Minute)) +} diff --git a/pkg/session/stats.go b/pkg/session/stats.go new file mode 100644 index 0000000000..a99c8ddb8b --- /dev/null +++ b/pkg/session/stats.go @@ -0,0 +1,190 @@ +package session + +import ( + "cmp" + "slices" + "time" + + "github.com/docker/docker-agent/pkg/chat" +) + +// Stats aggregates high-level interaction and usage metrics for a session. +// It is derived data computed by walking the session and its sub-sessions; it +// is never persisted and carries no billing authority. It backs end-of-run +// summaries such as the non-interactive (--exec) recap. +type Stats struct { + // ID is the session identifier. + ID string + + // Duration is the wall-clock span between the first and last message, + // approximated from message timestamps (see Session.Duration). + Duration time.Duration + + // Requests is the number of assistant responses that reported usage. + Requests int + + // ToolCalls is the number of completed tool calls (tool-role results). + ToolCalls int + + // ToolErrors is the number of completed tool calls that failed. + ToolErrors int + + // InputTokens counts new (uncached) input tokens. + InputTokens int64 + + // CachedInput counts input tokens served from cache (cache reads). + CachedInput int64 + + // CacheWrite counts input tokens written to the cache. + CacheWrite int64 + + // OutputTokens counts generated output tokens. + OutputTokens int64 + + // ReasoningTokens counts reasoning tokens (subset of output for some models). + ReasoningTokens int64 + + // Cost is the total session cost in dollars. + Cost float64 + + // Models holds the per-model breakdown, sorted by cost descending. + Models []ModelStats +} + +// ModelStats holds the per-model usage breakdown within a session. +type ModelStats struct { + Model string + Requests int + InputTokens int64 // new (uncached) input tokens + CachedInput int64 // cache reads + OutputTokens int64 + Cost float64 +} + +// ToolSuccesses returns the number of completed tool calls that succeeded. +func (s Stats) ToolSuccesses() int { return s.ToolCalls - s.ToolErrors } + +// SuccessRate returns the tool-call success rate as a percentage in [0,100]. +// It returns 0 when no tool calls were made. +func (s Stats) SuccessRate() float64 { + if s.ToolCalls == 0 { + return 0 + } + return float64(s.ToolCalls-s.ToolErrors) / float64(s.ToolCalls) * 100 +} + +// TotalInput returns all input tokens, including cache reads and writes. +func (s Stats) TotalInput() int64 { return s.InputTokens + s.CachedInput + s.CacheWrite } + +// CacheHitRate returns the share of input tokens served from cache as a +// percentage in [0,100]. It returns 0 when there is no countable input. +func (s Stats) CacheHitRate() float64 { + candidate := s.InputTokens + s.CachedInput + if candidate == 0 { + return 0 + } + return float64(s.CachedInput) / float64(candidate) * 100 +} + +// HasActivity reports whether the session did anything worth summarizing. +func (s Stats) HasActivity() bool { + return s.Requests > 0 || s.ToolCalls > 0 || s.Cost > 0 +} + +// Stats computes a Stats snapshot for the session. It walks sub-sessions so +// that delegated work (task transfers) is included in the totals, mirroring +// TotalCost's accounting. +func (s *Session) Stats() Stats { + // Duration takes its own lock; compute it before walking so we never + // hold a read lock across the call. + st := Stats{ + ID: s.ID, + Duration: s.Duration(), + } + + modelIndex := make(map[string]int) + + var walk func(sess *Session) + walk = func(sess *Session) { + for _, item := range sess.snapshotItems() { + switch { + case item.IsMessage(): + m := item.Message.Message + st.Cost += m.Cost + switch m.Role { + case chat.MessageRoleTool: + st.ToolCalls++ + if m.IsError { + st.ToolErrors++ + } + case chat.MessageRoleAssistant: + if m.Usage != nil { + st.Requests++ + st.InputTokens += m.Usage.InputTokens + st.CachedInput += m.Usage.CachedInputTokens + st.CacheWrite += m.Usage.CacheWriteTokens + st.OutputTokens += m.Usage.OutputTokens + st.ReasoningTokens += m.Usage.ReasoningTokens + + model := cmp.Or(m.Model, "unknown") + idx, ok := modelIndex[model] + if !ok { + idx = len(st.Models) + modelIndex[model] = idx + st.Models = append(st.Models, ModelStats{Model: model}) + } + ms := &st.Models[idx] + ms.Requests++ + ms.InputTokens += m.Usage.InputTokens + ms.CachedInput += m.Usage.CachedInputTokens + ms.OutputTokens += m.Usage.OutputTokens + ms.Cost += m.Cost + } + } + case item.IsSubSession(): + walk(item.SubSession) + } + // Item-level cost (e.g. compaction/summarization) lives outside any + // message, so it is added once per item, matching TotalCost. + st.Cost += item.Cost + } + } + walk(s) + + // Remote mode keeps per-message usage in MessageUsageHistory rather than in + // the message list, so fall back to it when the walk found no usage. + if st.Requests == 0 && len(s.MessageUsageHistory) > 0 { + for _, r := range s.MessageUsageHistory { + st.Requests++ + st.InputTokens += r.Usage.InputTokens + st.CachedInput += r.Usage.CachedInputTokens + st.CacheWrite += r.Usage.CacheWriteTokens + st.OutputTokens += r.Usage.OutputTokens + st.ReasoningTokens += r.Usage.ReasoningTokens + st.Cost += r.Cost + + model := cmp.Or(r.Model, "unknown") + idx, ok := modelIndex[model] + if !ok { + idx = len(st.Models) + modelIndex[model] = idx + st.Models = append(st.Models, ModelStats{Model: model}) + } + ms := &st.Models[idx] + ms.Requests++ + ms.InputTokens += r.Usage.InputTokens + ms.CachedInput += r.Usage.CachedInputTokens + ms.OutputTokens += r.Usage.OutputTokens + ms.Cost += r.Cost + } + } + + slices.SortFunc(st.Models, func(a, b ModelStats) int { + if c := cmp.Compare(b.Cost, a.Cost); c != 0 { + return c + } + return cmp.Compare(a.Model, b.Model) + }) + + return st +} diff --git a/pkg/session/stats_test.go b/pkg/session/stats_test.go new file mode 100644 index 0000000000..3d9f7c2505 --- /dev/null +++ b/pkg/session/stats_test.go @@ -0,0 +1,158 @@ +package session + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/chat" +) + +func assistantMsg(model string, cost float64, usage chat.Usage, createdAt string) *Message { + return &Message{ + AgentName: "root", + Message: chat.Message{ + Role: chat.MessageRoleAssistant, + Content: "ok", + Model: model, + Usage: &usage, + Cost: cost, + CreatedAt: createdAt, + }, + } +} + +func toolMsg(isError bool, createdAt string) *Message { + return &Message{ + Message: chat.Message{ + Role: chat.MessageRoleTool, + Content: "result", + ToolCallID: "call_1", + IsError: isError, + CreatedAt: createdAt, + }, + } +} + +func TestStats_EmptySession(t *testing.T) { + t.Parallel() + + st := New(WithID("sess-empty")).Stats() + + assert.Equal(t, "sess-empty", st.ID) + assert.False(t, st.HasActivity()) + assert.Zero(t, st.Requests) + assert.Zero(t, st.ToolCalls) + assert.Zero(t, st.Cost) + assert.Empty(t, st.Models) + assert.Zero(t, st.SuccessRate()) + assert.Zero(t, st.CacheHitRate()) +} + +func TestStats_AggregatesTokensToolsAndModels(t *testing.T) { + t.Parallel() + + sess := New(WithID("sess-1")) + sess.AddMessage(UserMessage("hi")) + sess.AddMessage(assistantMsg("cheap-model", 0.001, chat.Usage{ + InputTokens: 1000, + OutputTokens: 200, + }, "")) + sess.AddMessage(toolMsg(false, "")) + sess.AddMessage(toolMsg(true, "")) + sess.AddMessage(assistantMsg("pricey-model", 0.05, chat.Usage{ + InputTokens: 500, + CachedInputTokens: 1500, + CacheWriteTokens: 100, + OutputTokens: 300, + ReasoningTokens: 80, + }, "")) + + st := sess.Stats() + + assert.True(t, st.HasActivity()) + assert.Equal(t, 2, st.Requests) + assert.Equal(t, 2, st.ToolCalls) + assert.Equal(t, 1, st.ToolErrors) + assert.Equal(t, 1, st.ToolSuccesses()) + assert.InDelta(t, 50.0, st.SuccessRate(), 0.001) + + assert.Equal(t, int64(1500), st.InputTokens) + assert.Equal(t, int64(1500), st.CachedInput) + assert.Equal(t, int64(100), st.CacheWrite) + assert.Equal(t, int64(500), st.OutputTokens) + assert.Equal(t, int64(80), st.ReasoningTokens) + assert.Equal(t, int64(3100), st.TotalInput()) + assert.InDelta(t, 0.051, st.Cost, 1e-9) + + // 1500 cached out of (1500 new + 1500 cached) = 50%. + assert.InDelta(t, 50.0, st.CacheHitRate(), 0.001) + + // Models sorted by cost descending. + assert.Len(t, st.Models, 2) + assert.Equal(t, "pricey-model", st.Models[0].Model) + assert.Equal(t, "cheap-model", st.Models[1].Model) + assert.Equal(t, 1, st.Models[0].Requests) + assert.Equal(t, int64(1500), st.Models[0].CachedInput) +} + +func TestStats_IncludesSubSessions(t *testing.T) { + t.Parallel() + + sub := New(WithID("sub")) + sub.AddMessage(assistantMsg("sub-model", 0.02, chat.Usage{ + InputTokens: 400, + OutputTokens: 50, + }, "")) + sub.AddMessage(toolMsg(false, "")) + + parent := New(WithID("parent")) + parent.AddMessage(assistantMsg("main-model", 0.01, chat.Usage{ + InputTokens: 100, + OutputTokens: 20, + }, "")) + parent.AddSubSession(sub) + + st := parent.Stats() + + assert.Equal(t, 2, st.Requests) + assert.Equal(t, 1, st.ToolCalls) + assert.Equal(t, int64(500), st.InputTokens) + assert.InDelta(t, 0.03, st.Cost, 1e-9) + assert.Len(t, st.Models, 2) +} + +func TestStats_Duration(t *testing.T) { + t.Parallel() + + t0 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) + sess := New(WithID("sess-dur")) + sess.AddMessage(assistantMsg("m", 0.001, chat.Usage{InputTokens: 10, OutputTokens: 5}, t0.Format(time.RFC3339))) + sess.AddMessage(assistantMsg("m", 0.001, chat.Usage{InputTokens: 10, OutputTokens: 5}, t0.Add(90*time.Second).Format(time.RFC3339))) + + st := sess.Stats() + + assert.Equal(t, 90*time.Second, st.Duration) +} + +func TestStats_RemoteFallback(t *testing.T) { + t.Parallel() + + sess := New(WithID("sess-remote")) + sess.AddMessageUsageRecord("root", "remote-model", 0.04, &chat.Usage{ + InputTokens: 200, + CachedInputTokens: 600, + OutputTokens: 90, + }) + + st := sess.Stats() + + assert.True(t, st.HasActivity()) + assert.Equal(t, 1, st.Requests) + assert.Equal(t, int64(200), st.InputTokens) + assert.Equal(t, int64(600), st.CachedInput) + assert.InDelta(t, 0.04, st.Cost, 1e-9) + assert.Len(t, st.Models, 1) + assert.Equal(t, "remote-model", st.Models[0].Model) +}