Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) (command

out := cli.NewPrinter(cmd.OutOrStdout())

return f.runOrExec(ctx, out, args, useTUI)
return f.runOrExec(ctx, out, cmd.ErrOrStderr(), args, useTUI)
}

func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []string, useTUI bool) error {
func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, errOut io.Writer, args []string, useTUI bool) error {
slog.DebugContext(ctx, "Starting agent", "agent", f.agentName)

// Start profiling if requested
Expand Down Expand Up @@ -436,7 +436,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
// Non-interactive (--exec) runs never clean up the worktree: there
// is no safe moment to prompt, and silently discarding work would
// be surprising. The worktree is left in place for later inspection.
return f.handleExecMode(ctx, out, rt, sess, args)
return f.handleExecMode(ctx, out, errOut, rt, sess, args)
}

listenOpt, err := f.startAttachedServer(ctx, out, rt, sess)
Expand Down Expand Up @@ -823,7 +823,7 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes
return localRt, sess, nil
}

func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, rt runtime.Runtime, sess *session.Session, args []string) error {
func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, errOut io.Writer, rt runtime.Runtime, sess *session.Session, args []string) error {
if f.sessionReadOnly {
return errors.New("--session-read-only cannot be used with --exec: there is nothing to display without a TUI")
}
Expand All @@ -840,6 +840,7 @@ func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, rt
HideToolCalls: f.hideToolCalls,
OutputJSON: f.outputJSON,
AutoApprove: f.autoApprove,
SummaryOut: errOut,
}, rt, sess, userMessages)
if cliErr, ok := errors.AsType[cli.RuntimeError](err); ok {
return RuntimeError{Err: cliErr.Err}
Expand Down
14 changes: 14 additions & 0 deletions pkg/cli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ type Config struct {
AutoApprove bool
HideToolCalls bool
OutputJSON bool

// SummaryOut, when non-nil, receives the end-of-run recap (interaction
// summary, model usage, token/cost totals). It is kept separate from the
// agent's response stream so callers can route it to stderr and keep
// stdout machine-consumable. The recap is skipped in JSON output mode.
SummaryOut io.Writer
}

// Run executes an agent in non-TUI mode, handling user input and runtime events.
Expand Down Expand Up @@ -311,6 +317,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
}
}

// Print an end-of-run recap (interaction summary, model usage, token/cost
// totals) to the dedicated summary writer (typically stderr) so stdout stays
// machine-consumable. JSON mode emits machine-readable events instead, so it
// is skipped there.
if !cfg.OutputJSON && cfg.SummaryOut != nil {
NewPrinter(cfg.SummaryOut).PrintSessionSummary(sess.Stats())
}

// Wrap runtime errors to prevent duplicate error messages and usage display
if lastErr != nil {
return RuntimeError{Err: lastErr}
Expand Down
182 changes: 182 additions & 0 deletions pkg/cli/summary.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package cli

import (
"fmt"
"strconv"
"strings"
"time"

"github.com/docker/docker-agent/pkg/session"
)

// PrintSessionSummary writes a compact end-of-run recap: an interaction
// summary (duration, tool calls, success rate), a per-model usage table, and
// token/cost totals with cache savings. It is a no-op when the session
// recorded no measurable activity. Callers skip it in JSON output mode.
func (p *Printer) PrintSessionSummary(stats session.Stats) {
if !stats.HasActivity() {
return
}

p.Println()
p.Println(bold("─── Session summary ───"))

var rows [][2]string
if stats.ID != "" {
rows = append(rows, [2]string{"Session", stats.ID})
}
if stats.Duration > 0 {
rows = append(rows, [2]string{"Duration", formatDuration(stats.Duration)})
}
if stats.Requests > 0 {
rows = append(rows, [2]string{"Requests", strconv.Itoa(stats.Requests)})
}
if stats.ToolCalls > 0 {
rows = append(rows,
[2]string{"Tool calls", fmt.Sprintf("%d (%d ok, %d failed)", stats.ToolCalls, stats.ToolSuccesses(), stats.ToolErrors)},
[2]string{"Success rate", fmt.Sprintf("%.1f%%", stats.SuccessRate())},
)
}
p.printLabeledRows(rows)

if len(stats.Models) > 0 {
p.Println()
p.Println(bold("Model usage"))
p.printModelTable(stats.Models)
}

p.Println()
totals := [][2]string{
{"Tokens", fmt.Sprintf("%s in / %s out", groupDigits(stats.TotalInput()), groupDigits(stats.OutputTokens))},
{"Cost", formatCost(stats.Cost)},
}
if stats.CachedInput > 0 {
totals = append(totals, [2]string{"Cache", fmt.Sprintf("%s input tokens from cache (%.1f%%)", groupDigits(stats.CachedInput), stats.CacheHitRate())})
}
p.printLabeledRows(totals)
}

// printLabeledRows prints "label: value" pairs with the colon-terminated
// labels padded to a common width so the values line up. Padding is applied
// outside the color escape so it stays correct regardless of TTY coloring.
func (p *Printer) printLabeledRows(rows [][2]string) {
labelWidth := 0
for _, r := range rows {
if l := len(r[0]) + 1; l > labelWidth {
labelWidth = l
}
}
for _, r := range rows {
label := r[0] + ":"
p.Printf("%s%s %s\n", bold(label), strings.Repeat(" ", labelWidth-len(label)), r[1])
}
}

// printModelTable renders the per-model usage breakdown as an aligned table.
func (p *Printer) printModelTable(models []session.ModelStats) {
headers := []string{"MODEL", "REQS", "INPUT", "CACHE READ", "OUTPUT"}
rows := make([][]string, 0, len(models))
for _, m := range models {
rows = append(rows, []string{
m.Model,
strconv.Itoa(m.Requests),
groupDigits(m.InputTokens),
groupDigits(m.CachedInput),
groupDigits(m.OutputTokens),
})
}

widths := make([]int, len(headers))
for i, h := range headers {
widths[i] = len(h)
}
for _, row := range rows {
for i, c := range row {
if len(c) > widths[i] {
widths[i] = len(c)
}
}
}

p.Println(bold(formatTableRow(headers, widths)))
for _, row := range rows {
p.Println(formatTableRow(row, widths))
}
}

// formatTableRow left-aligns the first column (model name) and right-aligns the
// numeric columns, separating them with a fixed gutter.
func formatTableRow(cols []string, widths []int) string {
var b strings.Builder
for i, c := range cols {
if i > 0 {
b.WriteString(" ")
}
pad := strings.Repeat(" ", widths[i]-len(c))
if i == 0 {
b.WriteString(c)
b.WriteString(pad)
} else {
b.WriteString(pad)
b.WriteString(c)
}
}
return strings.TrimRight(b.String(), " ")
}

// groupDigits formats an integer with thousands separators (e.g. 70637 ->
// "70,637").
func groupDigits(n int64) string {
s := strconv.FormatInt(n, 10)
neg := strings.HasPrefix(s, "-")
if neg {
s = s[1:]
}
for i := len(s) - 3; i > 0; i -= 3 {
s = s[:i] + "," + s[i:]
}
if neg {
return "-" + s
}
return s
}

// formatCost renders a dollar amount, widening the precision for sub-cent
// values so small costs are not all displayed as "$0.00".
func formatCost(cost float64) string {
if cost < 0 {
return "-" + formatCost(-cost)
}
switch {
case cost < 0.0001:
return "$0.00"
case cost < 0.01:
return fmt.Sprintf("$%.4f", cost)
default:
return fmt.Sprintf("$%.2f", cost)
}
}

// formatDuration renders a duration in a compact, human-readable form.
func formatDuration(d time.Duration) string {
if d <= 0 {
return "0s"
}
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
}
if d < time.Hour {
m := int(d.Minutes())
s := int(d.Seconds()) % 60
if s == 0 {
return fmt.Sprintf("%dm", m)
}
return fmt.Sprintf("%dm %ds", m, s)
}
h := int(d.Hours())
m := int(d.Minutes()) % 60
if m == 0 {
return fmt.Sprintf("%dh", h)
}
return fmt.Sprintf("%dh %dm", h, m)
}
115 changes: 115 additions & 0 deletions pkg/cli/summary_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package cli

import (
"bytes"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/docker/docker-agent/pkg/session"
)

func renderSummary(stats session.Stats) string {
var buf bytes.Buffer
NewPrinter(&buf).PrintSessionSummary(stats)
return buf.String()
}

func TestPrintSessionSummary_NoActivityIsNoOp(t *testing.T) {
t.Parallel()

assert.Empty(t, renderSummary(session.Stats{ID: "abc"}))
}

func TestPrintSessionSummary_FullRecap(t *testing.T) {
t.Parallel()

out := renderSummary(session.Stats{
ID: "abc-123",
Duration: 90 * time.Second,
Requests: 4,
ToolCalls: 12,
ToolErrors: 2,
InputTokens: 1000,
CachedInput: 146112,
OutputTokens: 2059,
Cost: 0.1234,
Models: []session.ModelStats{
{Model: "gemini-2.5-pro", Requests: 12, InputTokens: 22264, CachedInput: 106844, OutputTokens: 237},
},
})

// Interaction summary.
assert.Contains(t, out, "Session summary")
assert.Contains(t, out, "abc-123")
assert.Contains(t, out, "1m 30s")
assert.Contains(t, out, "12 (10 ok, 2 failed)")
assert.Contains(t, out, "83.3%")

// Model usage table.
assert.Contains(t, out, "Model usage")
assert.Contains(t, out, "gemini-2.5-pro")
assert.Contains(t, out, "22,264")
assert.Contains(t, out, "106,844")

// Totals and cache savings.
assert.Contains(t, out, "147,112 in / 2,059 out")
assert.Contains(t, out, "$0.12")
assert.Contains(t, out, "146,112 input tokens from cache")
}

func TestPrintSessionSummary_NoToolCallsOmitsToolLines(t *testing.T) {
t.Parallel()

out := renderSummary(session.Stats{
ID: "no-tools",
Requests: 1,
InputTokens: 500,
OutputTokens: 100,
Cost: 0.002,
Models: []session.ModelStats{{Model: "m", Requests: 1, InputTokens: 500, OutputTokens: 100}},
})

assert.NotContains(t, out, "Tool calls")
assert.NotContains(t, out, "Success rate")
assert.Contains(t, out, "500 in / 100 out")
}

func TestGroupDigits(t *testing.T) {
t.Parallel()

cases := map[int64]string{
0: "0",
5: "5",
999: "999",
1000: "1,000",
70637: "70,637",
1234567: "1,234,567",
-1500: "-1,500",
}
for in, want := range cases {
assert.Equalf(t, want, groupDigits(in), "groupDigits(%d)", in)
}
}

func TestFormatCost(t *testing.T) {
t.Parallel()

assert.Equal(t, "$0.00", formatCost(0))
assert.Equal(t, "$0.00", formatCost(0.00005))
assert.Equal(t, "$0.0050", formatCost(0.005))
assert.Equal(t, "$0.12", formatCost(0.1234))
assert.Equal(t, "$3.00", formatCost(3))
}

func TestFormatDuration(t *testing.T) {
t.Parallel()

assert.Equal(t, "0s", formatDuration(0))
assert.Equal(t, "45s", formatDuration(45*time.Second))
assert.Equal(t, "1m", formatDuration(time.Minute))
assert.Equal(t, "1m 30s", formatDuration(90*time.Second))
assert.Equal(t, "1h", formatDuration(time.Hour))
assert.Equal(t, "1h 1m", formatDuration(time.Hour+time.Minute))
}
Loading
Loading