// ajhahn.de
// ← eeco
// Go 370 lines

package ai

import (
	"context"
	"errors"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/ajhahnde/eeco/internal/config"
)

// fakeProvider is a scripted Provider test double. Every Run increments
// calls and hands back text together with err, which lets tests assert
// whether (and how often) the Gate actually reached the provider.
type fakeProvider struct {
	calls int
	text  string
	err   error
}

// Name identifies the double as "fake".
func (p *fakeProvider) Name() string { return "fake" }

// Run records the invocation and returns the scripted text and error.
func (p *fakeProvider) Run(_ context.Context, _ Request) (Response, error) {
	p.calls++
	resp := Response{Text: p.text}
	return resp, p.err
}

// usageProvider is a Provider test double whose Response carries token
// accounting. Tests use it to check that the Gate threads Usage through to
// its Outcome. It never fails.
type usageProvider struct {
	text  string
	usage Usage
}

// Name identifies the double as "usage".
func (usageProvider) Name() string { return "usage" }

// Run returns the scripted text together with the scripted token usage.
func (u usageProvider) Run(_ context.Context, _ Request) (Response, error) {
	out := Response{Text: u.text, Usage: u.usage}
	return out, nil
}

// newGate builds a Gate around p for project "proj". Its state directory is
// a fresh "state" folder inside the test's temp dir. It returns the Gate and
// that state directory so callers can inspect parked prompts and the queue.
func newGate(t *testing.T, p Provider, consent bool, budget int) (*Gate, string) {
	t.Helper()
	dir := filepath.Join(t.TempDir(), "state")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		t.Fatal(err)
	}
	g := &Gate{
		Provider: p,
		Consent:  consent,
		Budget:   budget,
		StateDir: dir,
		Project:  "proj",
	}
	return g, dir
}

// assertParked fails the test unless out describes a parked (skipped, not
// run) pass. The parked prompt file must exist on disk, and the state
// directory's queue.md must contain an "ai-parked" item.
func assertParked(t *testing.T, state string, out Outcome) {
	t.Helper()
	if out.Ran || !out.Skipped {
		t.Fatalf("want Skipped, got %+v", out)
	}
	if len(out.Parked) == 0 {
		t.Fatal("Skipped outcome must record a parked-prompt path")
	}
	if _, statErr := os.Stat(out.Parked); statErr != nil {
		t.Fatalf("parked file missing: %v", statErr)
	}

	queuePath := filepath.Join(state, "queue.md")
	queue, err := os.ReadFile(queuePath)
	if err != nil {
		t.Fatalf("queue not written: %v", err)
	}
	if body := string(queue); !strings.Contains(body, "ai-parked") {
		t.Errorf("queue missing ai-parked item:\n%s", body)
	}
}

func TestGate_NoConsentParksWithoutSpending(t *testing.T) {
	// Consent withheld: the Gate must park the prompt and never call the provider.
	prov := &fakeProvider{text: "result"}
	gate, stateDir := newGate(t, prov, false, 5)

	req := Request{Label: "unit", User: "the prompt"}
	out, err := gate.Run(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}
	if n := prov.calls; n != 0 {
		t.Errorf("provider called %d times without consent; want 0", n)
	}
	assertParked(t, stateDir, out)
}

func TestGate_BudgetExhaustedParks(t *testing.T) {
	// A budget of one admits exactly one provider call; the next request parks.
	ctx := context.Background()
	prov := &fakeProvider{text: "ok"}
	gate, stateDir := newGate(t, prov, true, 1)

	first, err := gate.Run(ctx, Request{Label: "a", User: "p1"})
	if err != nil || !first.Ran {
		t.Fatalf("first call should run: out=%+v err=%v", first, err)
	}

	second, err := gate.Run(ctx, Request{Label: "b", User: "p2"})
	if err != nil {
		t.Fatal(err)
	}
	if prov.calls != 1 {
		t.Errorf("provider called %d times; budget 1 must cap at 1", prov.calls)
	}
	assertParked(t, stateDir, second)
}

func TestGate_ZeroBudgetParks(t *testing.T) {
	// With consent but a zero budget, nothing may be spent: the pass parks.
	prov := &fakeProvider{text: "ok"}
	gate, stateDir := newGate(t, prov, true, 0)

	out, err := gate.Run(context.Background(), Request{Label: "z", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	if n := prov.calls; n != 0 {
		t.Errorf("zero budget must not spend; calls=%d", n)
	}
	assertParked(t, stateDir, out)
}

func TestGate_ProviderErrorParksNotFatal(t *testing.T) {
	// A failing provider is a soft failure: the Gate parks and reports the
	// provider's error in the outcome reason instead of returning it.
	prov := &fakeProvider{err: errors.New("boom")}
	gate, stateDir := newGate(t, prov, true, 3)

	out, runErr := gate.Run(context.Background(), Request{Label: "e", User: "p"})
	if runErr != nil {
		t.Fatalf("provider failure must not be a hard error: %v", runErr)
	}
	assertParked(t, stateDir, out)
	if reason := out.Reason; !strings.Contains(reason, "boom") {
		t.Errorf("reason should carry provider error, got %q", reason)
	}
}

func TestGate_EmptyResultParks(t *testing.T) {
	// A whitespace-only provider answer counts as no answer and must park.
	prov := &fakeProvider{text: "   "}
	gate, stateDir := newGate(t, prov, true, 3)

	out, err := gate.Run(context.Background(), Request{Label: "blank", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	assertParked(t, stateDir, out)
}

func TestGate_SuccessReturnsText(t *testing.T) {
	fp := &fakeProvider{text: "  the answer  "}
	g, _ := newGate(t, fp, true, 1)
	out, err := g.Run(context.Background(), Request{Label: "ok", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	if !out.Ran || out.Skipped {
		t.Fatalf("want Ran, got %+v", out)
	}
	if out.Text != "  the answer  " {
		t.Errorf("Text = %q (gate must not mangle provider text)", out.Text)
	}
}

func TestSelect(t *testing.T) {
	// After C5 the provider set is {cli, none}: a configured `ai_command`
	// picks the CLI provider, everything else parks. The legacy
	// `ai_provider=anthropic` is tolerated and behaves exactly like auto
	// (the in-binary API provider was retired).
	cmd := []string{"echo", "hi"}
	tests := []struct {
		name string
		cfg  *config.Config
		want string // expected provider Name()
	}{
		{"nil cfg", nil, "none"},
		{"auto, no command", &config.Config{}, "none"},
		{"auto, command picks cli", &config.Config{AICommand: cmd}, "cli"},
		{"explicit cli with command", &config.Config{AIProvider: "cli", AICommand: cmd}, "cli"},
		{"explicit cli without command", &config.Config{AIProvider: "cli"}, "none"},
		{"legacy anthropic, no command, falls to none", &config.Config{AIProvider: "anthropic"}, "none"},
		{"legacy anthropic with command, falls to cli", &config.Config{AIProvider: "anthropic", AICommand: cmd}, "cli"},
		{"unknown provider with command falls back to cli", &config.Config{AIProvider: "bogus", AICommand: cmd}, "cli"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := Select(tc.cfg).Name(); got != tc.want {
				t.Errorf("Select() = %q, want %q", got, tc.want)
			}
		})
	}
}

func TestNotConfigured_RunReportsSentinel(t *testing.T) {
	// The fallback provider must report ErrNotConfigured, so callers can
	// detect it with errors.Is.
	p := notConfigured{}
	if _, err := p.Run(context.Background(), Request{}); !errors.Is(err, ErrNotConfigured) {
		t.Errorf("err = %v, want ErrNotConfigured", err)
	}
}

func TestCLIProvider_RunsConfiguredCommand(t *testing.T) {
	if _, err := os.Stat("/bin/sh"); err != nil {
		t.Skip("no /bin/sh")
	}
	// Echo stdin back: proves the folded prompt is fed in and stdout is
	// the result. An empty System folds to exactly User.
	p := cliProvider{argv: []string{"/bin/sh", "-c", "cat"}}
	got, err := p.Run(context.Background(), Request{User: "hello prompt"})
	if err != nil {
		t.Fatal(err)
	}
	if got.Text != "hello prompt" {
		t.Errorf("got %q, want %q", got.Text, "hello prompt")
	}
}

func TestCLIProvider_FoldsSystemAndUser(t *testing.T) {
	if _, err := os.Stat("/bin/sh"); err != nil {
		t.Skip("no /bin/sh")
	}
	p := cliProvider{argv: []string{"/bin/sh", "-c", "cat"}}
	got, err := p.Run(context.Background(), Request{System: "S", User: "U"})
	if err != nil {
		t.Fatal(err)
	}
	if got.Text != "S\n\nU" {
		t.Errorf("folded prompt = %q, want %q", got.Text, "S\n\nU")
	}
}

// TestFoldPrompt pins the single-turn folding rules:
//   - User alone passes through unchanged.
//   - A System block is joined to User by a blank line.
//   - An empty Messages transcript, whether nil or a zero-length slice,
//     falls through to the single-turn branch byte-for-byte, so single-turn
//     callers stay unchanged.
func TestFoldPrompt(t *testing.T) {
	tests := []struct {
		name string
		req  Request
		want string
	}{
		{"User-only fold", Request{User: "only"}, "only"},
		{"System+User fold", Request{System: "S", User: "U"}, "S\n\nU"},
		{"nil-Messages fold", Request{System: "S", User: "U", Messages: nil}, "S\n\nU"},
		// A non-nil but empty transcript must behave exactly like nil.
		{"empty-Messages fold", Request{System: "S", User: "U", Messages: []Message{}}, "S\n\nU"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := foldPrompt(tc.req); got != tc.want {
				t.Errorf("%s = %q, want byte-identical %q", tc.name, got, tc.want)
			}
		})
	}
}

func TestFoldPrompt_Transcript(t *testing.T) {
	req := Request{
		System: "SYS",
		User:   "ignored when Messages is set",
		Messages: []Message{
			{Role: "user", Text: "hello"},
			{Role: "assistant", Text: "hi there"},
			{Role: "user", Text: "more"},
		},
	}
	want := "SYS\n\nUser: hello\n\nAssistant: hi there\n\nUser: more"
	if got := foldPrompt(req); got != want {
		t.Errorf("transcript fold = %q, want %q", got, want)
	}
	// Without a System block the transcript leads with the first turn.
	noSys := Request{Messages: []Message{{Role: "user", Text: "just me"}}}
	if got := foldPrompt(noSys); got != "User: just me" {
		t.Errorf("system-less transcript = %q, want %q", got, "User: just me")
	}
}

func TestCLIProvider_EmptyArgvIsNotConfigured(t *testing.T) {
	_, err := cliProvider{}.Run(context.Background(), Request{})
	if !errors.Is(err, ErrNotConfigured) {
		t.Errorf("err = %v, want ErrNotConfigured", err)
	}
}

func TestGate_ThreadsUsageOnRan(t *testing.T) {
	// On a pass that runs, the Gate must copy the provider's token
	// accounting into the outcome.
	want := Usage{InputTokens: 12, CachedInputTokens: 3, OutputTokens: 7}
	prov := &usageProvider{text: "ok", usage: want}
	gate, _ := newGate(t, prov, true, 1)

	out, err := gate.Run(context.Background(), Request{Label: "u", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	if !out.Ran {
		t.Fatalf("want Ran, got %+v", out)
	}
	if out.Usage != want {
		t.Errorf("Usage = %+v, want it threaded from the provider", out.Usage)
	}
}

// fragCoAB is the attribution trailer key, built from pieces so that this
// tracked test file never contains the contiguous literal that eeco's own
// leak-guard would flag. The scanner used by these tests is passed inline as
// a func value rather than imported from internal/workflow, because that
// import would create a cycle.
const fragCoAB = "Co-" + "Authored-" + "By"

// inlineAttributionScanner is a minimal stand-in for the real attribution
// scanner. It returns one finding when s contains the trailer key and nil
// otherwise.
func inlineAttributionScanner(s string) []string {
	if !strings.Contains(s, fragCoAB) {
		return nil
	}
	return []string{"line 1: co-authored-by trailer"}
}

func TestGate_FilterBlocksAttributionAndRecordsHash(t *testing.T) {
	// A response containing an attribution trailer is blocked after the
	// provider has run. The pass parks, no text reaches the caller, and the
	// ledger records a parked pass together with the hash of the blocked
	// response.
	payload := fragCoAB + ": A Bot <b@x>\n"
	prov := &fakeProvider{text: payload}
	gate, stateDir := newGate(t, prov, true, 1)
	gate.Scanner = inlineAttributionScanner

	out, err := gate.Run(context.Background(), Request{Label: "evolve", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	assertParked(t, stateDir, out)
	if prov.calls != 1 {
		t.Errorf("provider should have run once before the block; calls=%d", prov.calls)
	}
	if !strings.Contains(out.Reason, "attribution") {
		t.Errorf("park reason should name the attribution block, got %q", out.Reason)
	}
	if out.Text != "" {
		t.Errorf("blocked response text must never reach the caller, got %q", out.Text)
	}

	ledger := readLedger(t, stateDir)
	if n := len(ledger.Records); n != 1 {
		t.Fatalf("want 1 ledger record, got %d", n)
	}
	rec := ledger.Records[0]
	if !rec.Parked || rec.Ran {
		t.Errorf("blocked pass must record ran=false parked=true, got %+v", rec)
	}
	wantHash := sha256Hex(payload)
	if rec.ResponseSHA256 != wantHash {
		t.Errorf("blocked pass must record the response hash, got %q want %q", rec.ResponseSHA256, wantHash)
	}
	if !strings.Contains(rec.ParkReason, "attribution") {
		t.Errorf("ledger park reason should name the attribution block, got %q", rec.ParkReason)
	}
}

func TestGate_FilterPassesCleanResponse(t *testing.T) {
	fp := &fakeProvider{text: "  a clean answer  "}
	g, state := newGate(t, fp, true, 1)
	g.Scanner = inlineAttributionScanner

	out, err := g.Run(context.Background(), Request{Label: "evolve", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	if !out.Ran || out.Skipped {
		t.Fatalf("clean response must pass, got %+v", out)
	}
	if out.Text != "  a clean answer  " {
		t.Errorf("Text = %q (filter must not mangle a clean response)", out.Text)
	}
	rec := readLedger(t, state).Records[0]
	if !rec.Ran || rec.Parked || rec.ResponseSHA256 == "" {
		t.Errorf("clean pass record = %+v, want ran with a response hash", rec)
	}
}

func TestGate_NilScannerSkipsFilter(t *testing.T) {
	// With Scanner left nil the filter is skipped, so even attribution text
	// must pass: the filter step is nil-safe.
	prov := &fakeProvider{text: fragCoAB + ": A Bot <b@x>\n"}
	gate, _ := newGate(t, prov, true, 1)

	out, err := gate.Run(context.Background(), Request{Label: "evolve", User: "p"})
	if err != nil {
		t.Fatal(err)
	}
	if !out.Ran {
		t.Fatalf("nil scanner must not block; got %+v", out)
	}
}

func TestUnderstand_IsGated(t *testing.T) {
	// Understand goes through the Gate: without consent the background pass
	// parks and never calls the provider.
	prov := &fakeProvider{text: "summary"}
	gate, stateDir := newGate(t, prov, false, 5)
	cfg := &config.Config{RepoRoot: t.TempDir(), Profile: config.ProfileGo}

	out, err := Understand(context.Background(), gate, cfg)
	if err != nil {
		t.Fatal(err)
	}
	if n := prov.calls; n != 0 {
		t.Errorf("Understand spent without consent (calls=%d)", n)
	}
	assertParked(t, stateDir, out)
}