Add frameworks query parameter to /v1/metrics endpoint

Co-authored-by: matthyx <20683409+matthyx@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-04 13:45:55 +00:00
parent 1f7dd6e5f5
commit 8d59a6074e
2 changed files with 77 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"github.com/gorilla/schema"
"github.com/google/uuid"
"github.com/kubescape/go-logger"
"github.com/kubescape/go-logger/helpers"
@@ -18,6 +19,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
// MetricsQueryParams query params for metrics endpoint
type MetricsQueryParams struct {
// Frameworks is a comma-separated list of frameworks to scan
// Example: "nsa,mitre,cis-v1.10.0"
// If not provided, all available frameworks will be scanned
Frameworks string `schema:"frameworks" json:"frameworks"`
}
// Metrics http listener for prometheus support
func (handler *HTTPHandler) Metrics(w http.ResponseWriter, r *http.Request) {
@@ -25,8 +34,16 @@ func (handler *HTTPHandler) Metrics(w http.ResponseWriter, r *http.Request) {
handler.state.setBusy(scanID)
defer handler.state.setNotBusy(scanID)
// Parse query parameters
metricsQueryParams := &MetricsQueryParams{}
if err := schema.NewDecoder().Decode(metricsQueryParams, r.URL.Query()); err != nil {
w.WriteHeader(http.StatusBadRequest)
handler.writeError(w, fmt.Errorf("failed to parse query params, reason: %s", err.Error()), scanID)
return
}
resultsFile := filepath.Join(OutputDir, scanID)
scanInfo := getPrometheusDefaultScanCommand(scanID, resultsFile)
scanInfo := getPrometheusDefaultScanCommand(scanID, resultsFile, metricsQueryParams.Frameworks)
scanParams := &scanRequestParams{
scanQueryParams: &ScanQueryParams{
@@ -69,7 +86,7 @@ func (handler *HTTPHandler) Metrics(w http.ResponseWriter, r *http.Request) {
w.Write(f)
}
func getPrometheusDefaultScanCommand(scanID, resultsFile string) *cautils.ScanInfo {
func getPrometheusDefaultScanCommand(scanID, resultsFile, frameworksParam string) *cautils.ScanInfo {
scanInfo := defaultScanInfo()
scanInfo.UseArtifactsFrom = getter.DefaultLocalStore // Load files from cache (this will prevent kubescape fom downloading the artifacts every time)
scanInfo.Submit = false // do not submit results every scan
@@ -82,11 +99,16 @@ func getPrometheusDefaultScanCommand(scanID, resultsFile string) *cautils.ScanIn
scanInfo.Output = resultsFile // results output
scanInfo.Format = envToString("KS_FORMAT", "prometheus") // default output format is prometheus
// Check if specific frameworks are requested via environment variable
frameworksEnv := envToString("KS_METRICS_FRAMEWORKS", "")
if frameworksEnv != "" {
// Check if specific frameworks are requested
// Priority: 1) query parameter, 2) environment variable, 3) default (all frameworks)
frameworksList := frameworksParam
if frameworksList == "" {
frameworksList = envToString("KS_METRICS_FRAMEWORKS", "")
}
if frameworksList != "" {
// Scan specific frameworks (comma-separated list)
frameworks := splitAndTrim(frameworksEnv, ",")
frameworks := splitAndTrim(frameworksList, ",")
scanInfo.SetPolicyIdentifiers(frameworks, utilsapisv1.KindFramework)
} else {
// Default: scan all available frameworks (including CIS)

View File

@@ -16,7 +16,7 @@ func TestGetPrometheusDefaultScanCommand(t *testing.T) {
scanID := "1234"
outputFile := filepath.Join(OutputDir, scanID)
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile)
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile, "")
assert.Equal(t, scanID, scanInfo.ScanID)
assert.Equal(t, outputFile, scanInfo.Output)
@@ -29,14 +29,13 @@ func TestGetPrometheusDefaultScanCommand(t *testing.T) {
assert.Equal(t, getter.DefaultLocalStore, scanInfo.UseArtifactsFrom)
})
t.Run("specific frameworks via environment variable", func(t *testing.T) {
// Set environment variable to scan specific frameworks
os.Setenv("KS_METRICS_FRAMEWORKS", "nsa,mitre,cis-v1.10.0")
defer os.Unsetenv("KS_METRICS_FRAMEWORKS")
t.Run("specific frameworks via query parameter", func(t *testing.T) {
// Ensure environment variable is not set
os.Unsetenv("KS_METRICS_FRAMEWORKS")
scanID := "5678"
outputFile := filepath.Join(OutputDir, scanID)
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile)
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile, "nsa,mitre,cis-v1.10.0")
assert.Equal(t, scanID, scanInfo.ScanID)
assert.Equal(t, outputFile, scanInfo.Output)
@@ -54,6 +53,50 @@ func TestGetPrometheusDefaultScanCommand(t *testing.T) {
assert.Equal(t, "mitre", scanInfo.PolicyIdentifier[1].Identifier)
assert.Equal(t, "cis-v1.10.0", scanInfo.PolicyIdentifier[2].Identifier)
})
t.Run("specific frameworks via environment variable", func(t *testing.T) {
// Set environment variable to scan specific frameworks
os.Setenv("KS_METRICS_FRAMEWORKS", "nsa,mitre")
defer os.Unsetenv("KS_METRICS_FRAMEWORKS")
scanID := "9012"
outputFile := filepath.Join(OutputDir, scanID)
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile, "")
assert.Equal(t, scanID, scanInfo.ScanID)
assert.Equal(t, outputFile, scanInfo.Output)
assert.Equal(t, "prometheus", scanInfo.Format)
assert.False(t, scanInfo.Submit)
assert.True(t, scanInfo.Local)
assert.True(t, scanInfo.FrameworkScan)
assert.False(t, scanInfo.ScanAll) // Don't scan all when specific frameworks are set
assert.False(t, scanInfo.HostSensorEnabled.GetBool())
assert.Equal(t, getter.DefaultLocalStore, scanInfo.UseArtifactsFrom)
// Verify specific frameworks are set
assert.Len(t, scanInfo.PolicyIdentifier, 2)
assert.Equal(t, "nsa", scanInfo.PolicyIdentifier[0].Identifier)
assert.Equal(t, "mitre", scanInfo.PolicyIdentifier[1].Identifier)
})
t.Run("query parameter overrides environment variable", func(t *testing.T) {
// Set environment variable
os.Setenv("KS_METRICS_FRAMEWORKS", "nsa")
defer os.Unsetenv("KS_METRICS_FRAMEWORKS")
scanID := "3456"
outputFile := filepath.Join(OutputDir, scanID)
// Query parameter should override environment variable
scanInfo := getPrometheusDefaultScanCommand(scanID, outputFile, "mitre,cis-v1.10.0")
assert.Equal(t, scanID, scanInfo.ScanID)
assert.False(t, scanInfo.ScanAll)
// Verify query parameter frameworks are used, not env var
assert.Len(t, scanInfo.PolicyIdentifier, 2)
assert.Equal(t, "mitre", scanInfo.PolicyIdentifier[0].Identifier)
assert.Equal(t, "cis-v1.10.0", scanInfo.PolicyIdentifier[1].Identifier)
})
}
func TestSplitAndTrim(t *testing.T) {