Go 277 lines
package projecttype
import (
"context"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
// makeRepo lays out a throwaway repository under a fresh t.TempDir() and
// returns the root path. Entries with a trailing "/" are created as
// directories; every other entry is written as an empty file, with any
// missing parent directories created first. Any filesystem error aborts the
// calling test.
func makeRepo(t *testing.T, entries ...string) string {
	t.Helper()

	root := t.TempDir()
	for _, entry := range entries {
		target := filepath.Join(root, entry)

		switch {
		case strings.HasSuffix(entry, "/"):
			// Directory entry: create the whole chain in one go.
			if err := os.MkdirAll(target, 0o755); err != nil {
				t.Fatalf("mkdir %s: %v", entry, err)
			}
		default:
			// File entry: make sure the parent exists, then drop an empty file.
			parent := filepath.Dir(target)
			if err := os.MkdirAll(parent, 0o755); err != nil {
				t.Fatalf("mkdir parent %s: %v", entry, err)
			}
			if err := os.WriteFile(target, nil, 0o644); err != nil {
				t.Fatalf("write %s: %v", entry, err)
			}
		}
	}
	return root
}
// mustCatalog returns the catalog produced by LoadCatalog, aborting the
// calling test if loading fails.
func mustCatalog(t *testing.T) *Catalog {
	t.Helper()

	catalog, loadErr := LoadCatalog()
	if loadErr != nil {
		t.Fatalf("LoadCatalog: %v", loadErr)
	}
	return catalog
}
func TestDetectDeterministicWinner(t *testing.T) {
cat := mustCatalog(t)
cases := []struct {
name string
entries []string
want Category
}{
{"go-cli", []string{"go.mod", "cmd/", "internal/"}, CLI},
{"terraform", []string{"main.tf", "modules/"}, Infra},
{"flutter", []string{"pubspec.yaml"}, Mobile},
{"gamedev", []string{"levels/", "assets/", "scenes/"}, GameDev},
{"ml", []string{"requirements.txt", "notebooks/", "data/"}, ML},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
root := makeRepo(t, tc.entries...)
got, err := Detect(context.Background(), cat, Options{RepoRoot: root})
if err != nil {
t.Fatalf("Detect: %v", err)
}
if got.Category != tc.want {
t.Fatalf("category = %q, want %q (conf %.3f)", got.Category, tc.want, got.Confidence)
}
if got.Source != SourceMarker {
t.Errorf("source = %q, want %q", got.Source, SourceMarker)
}
if got.Confidence < DefaultThreshold {
t.Errorf("confidence = %.3f, want >= %.2f", got.Confidence, DefaultThreshold)
}
if len(got.Dirs) == 0 {
t.Error("no dirs in result")
}
})
}
}
// TestDetectAmbiguousNonInteractiveTakesBestGuess checks that a tree holding
// only package.json, detected with nothing but RepoRoot set, still yields
// WebApp from the marker source while reporting a confidence below
// DefaultThreshold.
func TestDetectAmbiguousNonInteractiveTakesBestGuess(t *testing.T) {
	cat := mustCatalog(t)
	opts := Options{RepoRoot: makeRepo(t, "package.json")}

	res, err := Detect(context.Background(), cat, opts)
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}

	if res.Category != WebApp {
		t.Errorf("category = %q, want %q", res.Category, WebApp)
	}
	if res.Source != SourceMarker {
		t.Errorf("source = %q, want %q", res.Source, SourceMarker)
	}
	// The guess must be flagged as uncertain: confidence stays under the threshold.
	if res.Confidence >= DefaultThreshold {
		t.Errorf("confidence = %.3f, expected below threshold for an ambiguous tree", res.Confidence)
	}
}
// TestDetectNoSignalFallsBackToGeneric checks that a tree containing only
// README.md is reported as Generic with SourceFallback as the source.
func TestDetectNoSignalFallsBackToGeneric(t *testing.T) {
	cat := mustCatalog(t)
	opts := Options{RepoRoot: makeRepo(t, "README.md")}

	res, err := Detect(context.Background(), cat, opts)
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}

	if res.Category != Generic {
		t.Errorf("category = %q, want generic", res.Category)
	}
	if res.Source != SourceFallback {
		t.Errorf("source = %q, want %q", res.Source, SourceFallback)
	}
}
// TestDetectForced checks that Options.Forced overrides what the tree
// contains (ML is forced onto a go.mod + cmd/ tree), reporting SourceFlag
// with confidence 1.0, and that forcing an unknown category is an error.
func TestDetectForced(t *testing.T) {
	cat := mustCatalog(t)
	root := makeRepo(t, "go.mod", "cmd/")
	ctx := context.Background()

	res, err := Detect(ctx, cat, Options{RepoRoot: root, Forced: ML})
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}
	honoured := res.Category == ML && res.Source == SourceFlag && res.Confidence == 1.0
	if !honoured {
		t.Fatalf("forced result = %+v", res)
	}

	// A category name that is not recognised must be rejected rather than accepted.
	_, badErr := Detect(ctx, cat, Options{RepoRoot: root, Forced: Category("banana")})
	if badErr == nil {
		t.Error("expected error for unknown forced type")
	}
}
// fakePrompter is a test double for the layer-3 prompt: it answers with a
// scripted choice and remembers which candidates it was offered.
type fakePrompter struct {
	// Scripted reply handed back, field for field, by Pick.
	choice   Category
	describe bool
	freeText string

	// gotCands holds the candidate list passed to the most recent Pick call.
	gotCands []Category
}
// Pick stores the offered candidates on the receiver and replies with the
// scripted choice, describe flag and free text. The catalog argument is
// ignored and the returned error is always nil.
func (f *fakePrompter) Pick(candidates []Category, _ *Catalog) (Category, bool, string, error) {
	f.gotCands = candidates

	picked, describe, text := f.choice, f.describe, f.freeText
	return picked, describe, text, nil
}
// TestDetectInteractivePick checks that, for a tree holding only
// package.json, a Prompter that picks Fullstack decides the outcome: the
// result is Fullstack with SourceInteractive and confidence 1.0, and the
// prompter was handed a non-empty candidate list.
func TestDetectInteractivePick(t *testing.T) {
	cat := mustCatalog(t)
	root := makeRepo(t, "package.json")
	prompter := &fakePrompter{choice: Fullstack}

	res, err := Detect(context.Background(), cat, Options{RepoRoot: root, Prompter: prompter})
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}

	accepted := res.Category == Fullstack && res.Source == SourceInteractive && res.Confidence == 1.0
	if !accepted {
		t.Fatalf("interactive result = %+v", res)
	}
	if len(prompter.gotCands) == 0 {
		t.Error("prompter received no candidates")
	}
}
// TestDetectInteractiveDescribeRoutesToAI checks the path where the prompter
// answers with a free-text description instead of a choice: the description
// must appear in the prompt passed to the AI function, the AI's category is
// adopted with SourceAI, and the resulting dirs contain both the AI-supplied
// deviation "queue" and the dir "database".
func TestDetectInteractiveDescribeRoutesToAI(t *testing.T) {
	const (
		description = "a REST service"
		aiReply     = `here you go: {"category":"webapi","confidence":0.9,"dirs":["endpoints","models"],"deviations":["queue"]}`
	)

	cat := mustCatalog(t)
	root := makeRepo(t, "package.json")
	prompter := &fakePrompter{describe: true, freeText: description}

	// The stub verifies the operator's text reached the prompt before replying.
	ai := func(_ context.Context, prompt string) (string, error) {
		if !strings.Contains(prompt, description) {
			t.Errorf("prompt missing operator description; got:\n%s", prompt)
		}
		return aiReply, nil
	}

	res, err := Detect(context.Background(), cat, Options{RepoRoot: root, Prompter: prompter, AI: ai})
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}

	if res.Category != WebAPI || res.Source != SourceAI {
		t.Fatalf("AI result = %+v", res)
	}
	if !containsDir(res.Dirs, "queue") {
		t.Errorf("expected AI deviation 'queue' in dirs %v", res.Dirs)
	}
	if !containsDir(res.Dirs, "database") {
		t.Errorf("expected catalog dir 'database' retained in dirs %v", res.Dirs)
	}
}
// TestDetectForceAI checks that with ForceAI set and an AI function that
// answers "cli" at confidence 0.85, Detect reports CLI with SourceAI.
func TestDetectForceAI(t *testing.T) {
	cat := mustCatalog(t)
	root := makeRepo(t, "go.mod", "cmd/")

	stub := func(context.Context, string) (string, error) {
		return `{"category":"cli","confidence":0.85,"dirs":["commands"]}`, nil
	}
	opts := Options{RepoRoot: root, ForceAI: true, AI: stub}

	res, err := Detect(context.Background(), cat, opts)
	if err != nil {
		t.Fatalf("Detect: %v", err)
	}
	if res.Category != CLI || res.Source != SourceAI {
		t.Fatalf("force-ai result = %+v", res)
	}
}
func TestAIDegradesToGeneric(t *testing.T) {
cat := mustCatalog(t)
root := makeRepo(t, "README.md")
cases := []struct {
name string
ai AIFunc
}{
{"nil-aifunc", nil},
{"low-confidence", func(context.Context, string) (string, error) {
return `{"category":"cli","confidence":0.3}`, nil
}},
{"malformed", func(context.Context, string) (string, error) {
return "I cannot help with that", nil
}},
{"unknown-category", func(context.Context, string) (string, error) {
return `{"category":"banana","confidence":0.99}`, nil
}},
{"provider-error", func(context.Context, string) (string, error) {
return "", context.DeadlineExceeded
}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := Detect(context.Background(), cat, Options{RepoRoot: root, ForceAI: true, AI: tc.ai})
if err != nil {
t.Fatalf("Detect: %v", err)
}
if got.Category != Generic || got.Source != SourceFallback {
t.Fatalf("result = %+v, want generic fallback", got)
}
})
}
}
// TestConfidence pins the value of confidence(top, second) for a handful of
// score pairs: no signal, an unopposed winner, a dead tie, and a clear lead.
func TestConfidence(t *testing.T) {
	type pair struct {
		top, second, want float64
	}
	table := []pair{
		{top: 0, second: 0, want: 0},
		{top: 1.0, second: 0, want: 1.0},
		{top: 1.0, second: 1.0, want: 0.5},
		{top: 0.8, second: 0.2, want: 0.8},
	}

	for _, p := range table {
		got := confidence(p.top, p.second)
		if got == p.want {
			continue
		}
		t.Errorf("confidence(%.2f,%.2f) = %.4f, want %.4f", p.top, p.second, got, p.want)
	}
}
// TestMergeDirs checks that merging {"a","b"} with {"b","c",""," a "} yields
// exactly {"a","b","c"}: the repeated "b", the empty string, and the padded
// " a " contribute no extra entries.
func TestMergeDirs(t *testing.T) {
	base := []string{"a", "b"}
	extra := []string{"b", "c", "", " a "}
	want := []string{"a", "b", "c"}

	// Order-sensitive, element-wise comparison.
	if got := mergeDirs(base, extra); !slices.Equal(got, want) {
		t.Fatalf("mergeDirs = %v, want %v", got, want)
	}
}
// TestParseAIDetectFromProse checks that parseAIDetect extracts the JSON
// object from a chatty reply that wraps it in prose and a fenced code block,
// and that an entry listed under "deviations" ends up in the parsed Dirs.
func TestParseAIDetectFromProse(t *testing.T) {
	// Reply assembled piecewise so the embedded JSON stays readable.
	const reply = "Sure!\n" +
		"```json\n" +
		`{"category":"ml","confidence":0.7,"deviations":["data"]}` + "\n" +
		"```\n" +
		"hope that helps"

	parsed, ok := parseAIDetect(reply)
	if !ok {
		t.Fatal("parseAIDetect failed to extract object")
	}

	if parsed.Category != "ml" || parsed.Confidence != 0.7 {
		t.Fatalf("parsed = %+v", parsed)
	}
	if !containsDir(parsed.Dirs, "data") {
		t.Errorf("deviation not merged into dirs: %v", parsed.Dirs)
	}
}
// containsDir reports whether want occurs, as an exact string match, among
// dirs.
func containsDir(dirs []string, want string) bool {
	// slices.Index yields -1 when the value is absent.
	pos := slices.Index(dirs, want)
	return pos >= 0
}