ajhahn.de
← eeco
Go 131 lines
package ai

import (
	"encoding/json"
	"os"
	"path/filepath"
	"strings"
	"testing"
)

// readLedger reads the AICallsFilename file inside stateDir and JSON-decodes
// it into an aiCallLedger. Any read or decode failure aborts the calling test
// via t.Fatalf; the helper is marked with t.Helper so failures are attributed
// to the caller's line.
func readLedger(t *testing.T, stateDir string) aiCallLedger {
	t.Helper()

	ledgerPath := filepath.Join(stateDir, AICallsFilename)
	raw, readErr := os.ReadFile(ledgerPath)
	if readErr != nil {
		t.Fatalf("read ledger: %v", readErr)
	}

	var ledger aiCallLedger
	decodeErr := json.Unmarshal(raw, &ledger)
	if decodeErr != nil {
		t.Fatalf("decode ledger: %v", decodeErr)
	}
	return ledger
}

// TestRecordCall_RanAndParkedAppend records one ran call followed by one
// parked call against the same state directory, then verifies that both
// records were appended in order with the expected flags, hashes,
// provider/model, token counts, and park reason.
func TestRecordCall_RanAndParkedAppend(t *testing.T) {
	dir := t.TempDir()
	gate := &Gate{StateDir: dir}
	req := Request{Label: "evolve", System: "S", User: "U"}

	// First call: a pass that ran, with a response and non-zero usage.
	ranUsage := Usage{InputTokens: 10, CachedInputTokens: 4, OutputTokens: 6}
	gate.recordCall(req, "anthropic", "claude-haiku-4-5", "the answer",
		true, false, "", ranUsage, nil)

	// Second call: a parked pass with an empty response, a reason, and
	// zero usage.
	gate.recordCall(req, "none", "", "",
		false, true, "AI budget exhausted (cap 0)", Usage{}, nil)

	ledger := readLedger(t, dir)
	if got := len(ledger.Records); got != 2 {
		t.Fatalf("want 2 records, got %d", got)
	}

	// The ran record comes first because it was recorded first.
	first := ledger.Records[0]
	if first.Parked || !first.Ran {
		t.Errorf("first record: want ran, got %+v", first)
	}
	// The prompt hash is expected over system and user text joined by a
	// blank line.
	if want := sha256Hex("S\n\nU"); first.PromptSHA256 != want {
		t.Errorf("prompt hash = %q, want hash of folded prompt", first.PromptSHA256)
	}
	if want := sha256Hex("the answer"); first.ResponseSHA256 != want {
		t.Errorf("response hash = %q, want hash of response text", first.ResponseSHA256)
	}
	if !(first.Provider == "anthropic" && first.Model == "claude-haiku-4-5") {
		t.Errorf("provider/model not recorded: %+v", first)
	}
	wantTokens := aiCallTokens{Input: 10, CachedInput: 4, Output: 6}
	if first.Tokens != wantTokens {
		t.Errorf("tokens = %+v, want input=10 cached=4 output=6", first.Tokens)
	}

	second := ledger.Records[1]
	if !second.Parked || second.Ran {
		t.Errorf("second record: want parked, got %+v", second)
	}
	if hash := second.ResponseSHA256; hash != "" {
		t.Errorf("parked record must carry no response hash, got %q", hash)
	}
	if second.ParkReason == "" {
		t.Error("parked record must carry the park reason")
	}
}

func TestRecordCall_NoStateDirIsNoop(t *testing.T) {
	g := &Gate{StateDir: ""}
	// Must not panic and must not write anywhere.
	g.recordCall(Request{Label: "x", User: "p"}, "none", "", "", false, true, "r", Usage{}, nil)
}

// TestAppendAICall_CorruptFileResets seeds the ledger path with invalid JSON
// and checks that appendAICall still succeeds, leaving a ledger whose only
// record is the newly appended one.
func TestAppendAICall_CorruptFileResets(t *testing.T) {
	dir := t.TempDir()
	ledgerPath := filepath.Join(dir, AICallsFilename)
	garbage := []byte("{not json")
	if writeErr := os.WriteFile(ledgerPath, garbage, 0o644); writeErr != nil {
		t.Fatal(writeErr)
	}

	// An undecodable ledger must not block appends: it is treated as empty,
	// so the new record becomes the only entry.
	rec := aiCallRecord{Label: "fresh", Provider: "none", Ran: false, Parked: true}
	appendErr := appendAICall(dir, rec)
	if appendErr != nil {
		t.Fatalf("append over corrupt file: %v", appendErr)
	}

	records := readLedger(t, dir).Records
	if !(len(records) == 1 && records[0].Label == "fresh") {
		t.Errorf("corrupt file must reset to a single fresh record, got %+v", records)
	}
}

// TestRecordCall_NoToolsOmitsField guards backward compatibility. A record
// written with nil tools (as every pre-v1.7.0 caller does) must serialize
// without a "tools" key, so that older ledgers and the five existing call
// sites keep round-tripping byte-identically. The check is made on the raw
// file bytes rather than the decoded struct, because the absence of the key
// is only observable in the serialized form.
func TestRecordCall_NoToolsOmitsField(t *testing.T) {
	dir := t.TempDir()
	gate := &Gate{StateDir: dir}
	req := Request{Label: "evolve", User: "U"}
	gate.recordCall(req, "anthropic", "m", "ans", true, false, "", Usage{}, nil)

	raw, err := os.ReadFile(filepath.Join(dir, AICallsFilename))
	if err != nil {
		t.Fatal(err)
	}
	if ledgerText := string(raw); strings.Contains(ledgerText, `"tools"`) {
		t.Errorf("nil tools must not emit a tools key; ledger = %s", raw)
	}
}

// TestRecordCall_WithToolsSerializes checks that when recordCall is given a
// list of invoked tool names, the ledger record read back carries those names
// in its Tools field, in the same order.
func TestRecordCall_WithToolsSerializes(t *testing.T) {
	dir := t.TempDir()
	gate := &Gate{StateDir: dir}
	invoked := []string{"search_knowledge", "project_brief"}
	gate.recordCall(Request{Label: "tui-request", User: "U"}, "anthropic", "m", "ans",
		true, false, "", Usage{}, invoked)

	records := readLedger(t, dir).Records
	if n := len(records); n != 1 {
		t.Fatalf("want 1 record, got %d", n)
	}

	tools := records[0].Tools
	matches := len(tools) == 2 &&
		tools[0] == "search_knowledge" &&
		tools[1] == "project_brief"
	if !matches {
		t.Errorf("tools = %v, want [search_knowledge project_brief]", tools)
	}
}