Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions cmd/aibomgen-cli/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func runGenerate(cmd *cobra.Command, args []string) error {

// Print summary.
if len(written) == 0 {
genUI.PrintNoModelsFound()
genUI.PrintNoBOMsWritten()
return nil
}

Expand Down Expand Up @@ -245,7 +245,7 @@ func runModelIDMode(genUI *ui.GenerateUI, modelIDs []string, mode, hfToken strin

if !quiet {
workflow = ui.NewWorkflow(os.Stdout, "")
processTaskIdx = workflow.AddTask("Processing models")
processTaskIdx = workflow.AddTask("Processing possible models")
writeTaskIdx = workflow.AddTask("Writing output")
workflow.Start()
}
Expand Down Expand Up @@ -327,7 +327,7 @@ func runModelIDMode(genUI *ui.GenerateUI, modelIDs []string, mode, hfToken strin
}

if !quiet && workflow != nil {
workflow.CompleteTask(processTaskIdx, fmt.Sprintf("%d model(s)", len(boms)))
workflow.CompleteTask(processTaskIdx, fmt.Sprintf("%d possible model(s)", len(modelIDs)))
workflow.StartTask(writeTaskIdx, "")
workflow.CompleteTask(writeTaskIdx, fmt.Sprintf("%d file(s)", len(boms)))
workflow.Stop()
Expand Down Expand Up @@ -389,14 +389,17 @@ type modelTracker struct {
// detail is empty for a clean success (datasets are shown on sub-lines instead).
func modelOutcome(t *modelTracker, hasToken bool) (mark, detail string) {
switch {
case t == nil || !t.complete:
case t == nil:
return ui.GetCrossMark(), ui.Error.Render("→ BOM build failed")

case !t.apiOK && t.notFound:
case t.notFound && !t.apiOK:
if hasToken {
return ui.GetWarnMark(), ui.Warning.Render("→ not found on HF Hub")
return ui.GetCrossMark(), ui.Error.Render("→ not found on HF Hub; no BOM written")
}
return ui.GetWarnMark(), ui.Warning.Render("→ not found on HF Hub (or private – set --hf-token)")
return ui.GetCrossMark(), ui.Error.Render("→ not found (or private – set --hf-token); no BOM written")

Comment thread
wiebe-vandendriessche marked this conversation as resolved.
case !t.complete:
return ui.GetCrossMark(), ui.Error.Render("→ BOM build failed")

case t.fetchErr:
if fetcher.IsUnauthorized(t.fetchErrVal) {
Expand Down
10 changes: 5 additions & 5 deletions cmd/aibomgen-cli/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func runScan(cmd *cobra.Command, args []string) error {
// Print summary.
if len(written) == 0 {
genUI := ui.NewGenerateUI(cmd.OutOrStdout(), quiet)
genUI.PrintNoModelsFound()
genUI.PrintNoBOMsWritten()
return nil
}

Expand Down Expand Up @@ -202,8 +202,8 @@ func runScanDirectory(inputPath, mode, hfToken string, timeout time.Duration, qu

if !quiet {
workflow = ui.NewWorkflow(os.Stdout, "")
scanTaskIdx = workflow.AddTask("Scanning for AI imports")
processTaskIdx = workflow.AddTask("Processing models")
scanTaskIdx = workflow.AddTask("Scanning for possible AI imports")
processTaskIdx = workflow.AddTask("Processing possible models")
writeTaskIdx = workflow.AddTask("Writing output")
workflow.Start()
}
Expand All @@ -223,7 +223,7 @@ func runScanDirectory(inputPath, mode, hfToken string, timeout time.Duration, qu
}

if !quiet && workflow != nil {
workflow.CompleteTask(scanTaskIdx, fmt.Sprintf("found %d model(s)", len(discoveries)))
workflow.CompleteTask(scanTaskIdx, fmt.Sprintf("found %d possible model(s)", len(discoveries)))
}

if len(discoveries) == 0 {
Expand Down Expand Up @@ -309,7 +309,7 @@ func runScanDirectory(inputPath, mode, hfToken string, timeout time.Duration, qu
}

if !quiet && workflow != nil {
workflow.CompleteTask(processTaskIdx, fmt.Sprintf("%d model(s)", len(boms)))
workflow.CompleteTask(processTaskIdx, fmt.Sprintf("%d possible model(s)", len(discoveries)))
workflow.StartTask(writeTaskIdx, "")
workflow.CompleteTask(writeTaskIdx, fmt.Sprintf("%d file(s)", len(boms)))
workflow.Stop()
Expand Down
8 changes: 4 additions & 4 deletions internal/ui/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,14 @@ func (g *GenerateUI) PrintSummary(filesWritten int, outputDir, format string) {
fmt.Fprintln(g.writer, SuccessBox.Render(summary.String()))
}

// PrintNoModelsFound prints a message when no models are found.
func (g *GenerateUI) PrintNoModelsFound() {
// PrintNoBOMsWritten prints a message when no BOMs were written.
func (g *GenerateUI) PrintNoBOMsWritten() {
if g.quiet {
return
}

msg := "No models detected; no AIBOM files written."
fmt.Fprintln(g.writer, Warning.Render(GetWarnMark()+" "+msg))
msg := "No BOMs written."
fmt.Fprintln(g.writer, Error.Render(GetCrossMark()+" "+msg))
Comment thread
wiebe-vandendriessche marked this conversation as resolved.
}

// LogStep prints a simple log message (non-workflow mode).
Expand Down
24 changes: 22 additions & 2 deletions pkg/aibomgen/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,25 @@ func BuildPerDiscovery(discoveries []scanner.Discovery, opts GenerateOptions) ([

var resp *fetcher.ModelAPIResponse
var readme *fetcher.ModelReadmeCard
var apiNotFound bool

if modelID != "" {
if r, err := fetchers.modelAPI.Fetch(modelID); err == nil {
resp = r
progress(ProgressEvent{Type: EventFetchAPIComplete, ModelID: modelID})
} else {
if fetcher.IsNotFound(err) || fetcher.IsUnauthorized(err) {
apiNotFound = true
}
progress(ProgressEvent{Type: EventError, ModelID: modelID, Error: err, Message: fetchErrMessage("API", err)})
}

// Skip BOM generation if API fetch returned not found or unauthorized (model not accessible on HF)
if apiNotFound {
Comment thread
wiebe-vandendriessche marked this conversation as resolved.
progress(ProgressEvent{Type: EventModelComplete, ModelID: modelID, Message: "model skipped: API not found or unauthorized"})
continue
}

if c, err := fetchers.modelReadme.Fetch(modelID); err == nil {
readme = c
progress(ProgressEvent{Type: EventFetchReadmeComplete, ModelID: modelID})
Expand Down Expand Up @@ -385,19 +395,29 @@ func BuildFromModelIDs(modelIDs []string, opts GenerateOptions) ([]DiscoveredBOM
continue
}

bomBuilder := newBOMBuilder()

progress(ProgressEvent{Type: EventFetchStart, ModelID: modelID, Index: i, Total: len(modelIDs)})

// Fetch API metadata.
resp, err := fetchers.modelAPI.Fetch(modelID)
var apiNotFound bool
if err != nil {
if fetcher.IsNotFound(err) || fetcher.IsUnauthorized(err) {
apiNotFound = true
}
progress(ProgressEvent{Type: EventError, ModelID: modelID, Error: err, Message: "API fetch failed"})
resp = nil
} else {
progress(ProgressEvent{Type: EventFetchAPIComplete, ModelID: modelID})
}

// Skip BOM generation if API fetch returned not found or unauthorized (model not accessible on HF)
if apiNotFound {
Comment thread
wiebe-vandendriessche marked this conversation as resolved.
progress(ProgressEvent{Type: EventModelComplete, ModelID: modelID, Message: "model skipped: API not found or unauthorized"})
continue
}

bomBuilder := newBOMBuilder()

// Fetch README.
readme, err := fetchers.modelReadme.Fetch(modelID)
if err != nil {
Expand Down
154 changes: 154 additions & 0 deletions pkg/aibomgen/generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@ func (m *mockDatasetReadmeFetcher) Fetch(id string) (*fetcher.DatasetReadmeCard,
return &fetcher.DatasetReadmeCard{}, nil
}

func successFetcherSet() fetcherSet {
return fetcherSet{
modelAPI: &mockModelAPIFetcher{
fetchFunc: func(id string) (*fetcher.ModelAPIResponse, error) {
return &fetcher.ModelAPIResponse{ID: id}, nil
},
},
modelReadme: &mockModelReadmeFetcher{
fetchFunc: func(id string) (*fetcher.ModelReadmeCard, error) {
return &fetcher.ModelReadmeCard{}, nil
},
},
datasetAPI: &mockDatasetAPIFetcher{
fetchFunc: func(id string) (*fetcher.DatasetAPIResponse, error) {
return &fetcher.DatasetAPIResponse{ID: id}, nil
},
},
datasetReadme: &mockDatasetReadmeFetcher{
fetchFunc: func(id string) (*fetcher.DatasetReadmeCard, error) {
return &fetcher.DatasetReadmeCard{}, nil
},
},
modelTree: &fetcher.DummyModelTreeFetcher{},
}
}

func TestBuildDummyBOM(t *testing.T) {
// Save originals.
originalDummyFetcherSet := newDummyFetcherSet
Expand Down Expand Up @@ -271,6 +297,9 @@ func TestBuildPerDiscovery(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
Comment thread
wiebe-vandendriessche marked this conversation as resolved.
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -311,6 +340,9 @@ func TestBuildPerDiscovery(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand All @@ -335,6 +367,9 @@ func TestBuildPerDiscovery(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -362,6 +397,9 @@ func TestBuildPerDiscovery(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -613,6 +651,42 @@ func TestBuildPerDiscovery(t *testing.T) {
}
},
},
{
name: "skips model when API returns 401 (unauthorized) (model not found is also 401)",
args: args{
discoveries: []scanner.Discovery{
{ID: "private-or-missing-model", Name: "private-or-missing-model", Type: "huggingface"},
},
opts: GenerateOptions{Timeout: 1 * time.Second},
},
setup: func() {
newBOMBuilder = func() bomBuilder {
return &mockBOMBuilder{
buildFunc: func(bctx builder.BuildContext) (*cdx.BOM, error) {
return &cdx.BOM{}, nil
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
Comment thread
wiebe-vandendriessche marked this conversation as resolved.
return fetcherSet{
modelAPI: &mockModelAPIFetcher{
fetchFunc: func(id string) (*fetcher.ModelAPIResponse, error) {
return nil, &fetcher.HFError{StatusCode: 401}
},
},
modelReadme: &mockModelReadmeFetcher{},
datasetAPI: &mockDatasetAPIFetcher{},
datasetReadme: &mockDatasetReadmeFetcher{},
}
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
if len(got) != 0 {
t.Errorf("Expected 0 BOMs for model unauthorized on HF (401), got %d", len(got))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -766,6 +840,9 @@ func TestBuildFromModelIDs(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand All @@ -792,6 +869,9 @@ func TestBuildFromModelIDs(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -833,6 +913,9 @@ func TestBuildFromModelIDs(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -860,6 +943,9 @@ func TestBuildFromModelIDs(t *testing.T) {
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return successFetcherSet()
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
Expand Down Expand Up @@ -982,6 +1068,74 @@ func TestBuildFromModelIDs(t *testing.T) {
}
},
},
{
name: "skips model when API returns 404 (not found)",
args: args{
modelIDs: []string{"org/nonexistent-model"},
opts: GenerateOptions{Timeout: 1 * time.Second},
},
setup: func() {
newBOMBuilder = func() bomBuilder {
return &mockBOMBuilder{
buildFunc: func(bctx builder.BuildContext) (*cdx.BOM, error) {
return &cdx.BOM{}, nil
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return fetcherSet{
modelAPI: &mockModelAPIFetcher{
fetchFunc: func(id string) (*fetcher.ModelAPIResponse, error) {
return nil, &fetcher.HFError{StatusCode: 404}
},
},
modelReadme: &mockModelReadmeFetcher{},
datasetAPI: &mockDatasetAPIFetcher{},
datasetReadme: &mockDatasetReadmeFetcher{},
}
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
if len(got) != 0 {
t.Errorf("Expected 0 BOMs for model not found on HF (404), got %d", len(got))
}
},
},
{
name: "skips model when API returns 401 (unauthorized)",
args: args{
modelIDs: []string{"org/private-or-missing-model"},
opts: GenerateOptions{Timeout: 1 * time.Second},
},
setup: func() {
newBOMBuilder = func() bomBuilder {
return &mockBOMBuilder{
buildFunc: func(bctx builder.BuildContext) (*cdx.BOM, error) {
return &cdx.BOM{}, nil
},
}
}
newFetcherSet = func(httpClient *http.Client) fetcherSet {
return fetcherSet{
modelAPI: &mockModelAPIFetcher{
fetchFunc: func(id string) (*fetcher.ModelAPIResponse, error) {
return nil, &fetcher.HFError{StatusCode: 401}
},
},
modelReadme: &mockModelReadmeFetcher{},
datasetAPI: &mockDatasetAPIFetcher{},
datasetReadme: &mockDatasetReadmeFetcher{},
}
}
},
wantErr: false,
check: func(t *testing.T, got []DiscoveredBOM) {
if len(got) != 0 {
t.Errorf("Expected 0 BOMs for model unauthorized on HF (401), got %d", len(got))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
Loading