use channels for triggering scan

This commit is contained in:
David Wertenteil
2022-05-19 14:18:11 +03:00
parent d08fdf2e9e
commit 3abd59e290
6 changed files with 293 additions and 88 deletions

View File

@@ -7,13 +7,13 @@ Running `kubescape` will start up a webserver on port `8080` which will serve th
### Trigger scan
* POST `/v1/scan` - Trigger a kubescape scan. The server will return an ID and will execute the scanning asynchronously
* * `wait`: scan synchronously (return results and not ID). Use only in small clusters are with an increased timeout
* * `keep`: Do not delete results from local storage after returning
* * `wait=true`: scan synchronously (return results and not ID). Use only in small clusters are with an increased timeout. default is `wait=false`
* * `keep=true`: Do not delete results from local storage after returning. default is `keep=false`
### Get results
* GET `/v1/results` - Request kubescape scan results
* * query `id=<string>` -> ID returned when triggering the scan action. If empty will return latest results
* * query `keep` -> Do not delete results from local storage after returning
* * query `keep=true` -> Do not delete results from local storage after returning. default is `keep=false`
### Check scanning progress status
Check the scanning status - is the scanning in progress or done. This is meant for a waiting mechanize since the API does not return the entire results object when the scanning is done

View File

@@ -0,0 +1,107 @@
package v1
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sync"
"github.com/armosec/kubescape/v2/core/cautils/logger"
"github.com/armosec/kubescape/v2/core/cautils/logger/helpers"
utilsmetav1 "github.com/armosec/opa-utils/httpserver/meta/v1"
"github.com/gorilla/schema"
)
type scanResponseChan struct {
scanResponseChan map[string]chan *utilsmetav1.Response
mtx sync.RWMutex
}
// get response object chan
func (resChan *scanResponseChan) get(key string) chan *utilsmetav1.Response {
resChan.mtx.RLock()
defer resChan.mtx.RUnlock()
return resChan.scanResponseChan[key]
}
// set chan for response object
func (resChan *scanResponseChan) set(key string) {
resChan.mtx.Lock()
defer resChan.mtx.Unlock()
resChan.scanResponseChan[key] = make(chan *utilsmetav1.Response)
}
// push response object to chan
func (resChan *scanResponseChan) push(key string, resp *utilsmetav1.Response) {
resChan.mtx.Lock()
defer resChan.mtx.Unlock()
if _, ok := resChan.scanResponseChan[key]; ok {
resChan.scanResponseChan[key] <- resp
}
}
// delete channel
func (resChan *scanResponseChan) delete(key string) {
resChan.mtx.Lock()
defer resChan.mtx.Unlock()
delete(resChan.scanResponseChan, key)
}
func newScanResponseChan() *scanResponseChan {
return &scanResponseChan{
scanResponseChan: make(map[string]chan *utilsmetav1.Response),
mtx: sync.RWMutex{},
}
}
type ScanQueryParams struct {
ReturnResults bool `schema:"wait"` // wait for scanning to complete (synchronized request)
KeepResults bool `schema:"keep"` // do not delete results after returning (relevant only for synchronized requests)
}
type ResultsQueryParams struct {
ScanID string `schema:"id"`
KeepResults bool `schema:"keep"` // do not delete results after returning (default will delete results)
AllResults bool `schema:"all"` // delete all results
}
type StatusQueryParams struct {
ScanID string `schema:"id"`
}
// scanRequestParams params passed to channel
type scanRequestParams struct {
scanRequest *utilsmetav1.PostScanRequest // request as received from api
scanQueryParams *ScanQueryParams // request as received from api
scanID string // generated scan ID
}
func getScanParamsFromRequest(r *http.Request, scanID string) (*scanRequestParams, error) {
defer r.Body.Close()
scanRequestParams := &scanRequestParams{}
scanQueryParams := &ScanQueryParams{}
if err := schema.NewDecoder().Decode(scanQueryParams, r.URL.Query()); err != nil {
return scanRequestParams, fmt.Errorf("failed to parse query params, reason: %s", err.Error())
}
readBuffer, err := ioutil.ReadAll(r.Body)
if err != nil {
// handler.writeError(w, fmt.Errorf("failed to read request body, reason: %s", err.Error()), scanID)
return scanRequestParams, fmt.Errorf("failed to read request body, reason: %s", err.Error())
}
logger.L().Info("REST API received scan request", helpers.String("body", string(readBuffer)))
scanRequest := &utilsmetav1.PostScanRequest{}
if err := json.Unmarshal(readBuffer, &scanRequest); err != nil {
return scanRequestParams, fmt.Errorf("failed to parse request payload, reason: %s", err.Error())
}
scanRequestParams.scanID = scanID
scanRequestParams.scanQueryParams = scanQueryParams
scanRequestParams.scanRequest = scanRequest
return scanRequestParams, nil
}

View File

@@ -0,0 +1,76 @@
package v1
import (
"bytes"
"encoding/json"
"net/http"
"net/url"
"testing"
utilsmetav1 "github.com/armosec/opa-utils/httpserver/meta/v1"
"github.com/armosec/utils-go/boolutils"
"github.com/stretchr/testify/assert"
)
func TestGetScanParamsFromRequest(t *testing.T) {
{
body := utilsmetav1.PostScanRequest{
Submit: boolutils.BoolPointer(true),
HostScanner: boolutils.BoolPointer(true),
Account: "aaaaaaaaaa",
}
jsonBytes, err := json.Marshal(body)
assert.NoError(t, err)
u := url.URL{
Scheme: "http",
Host: "bla",
Path: "bla",
RawQuery: "wait=true&keep=true",
}
request, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(jsonBytes))
assert.NoError(t, err)
scanID := "ccccccc"
req, err := getScanParamsFromRequest(request, scanID)
assert.NoError(t, err)
assert.Equal(t, scanID, req.scanID)
assert.True(t, req.scanQueryParams.KeepResults)
assert.True(t, req.scanQueryParams.ReturnResults)
assert.True(t, *req.scanRequest.HostScanner)
assert.True(t, *req.scanRequest.Submit)
assert.Equal(t, "aaaaaaaaaa", req.scanRequest.Account)
}
{
body := utilsmetav1.PostScanRequest{
Submit: boolutils.BoolPointer(false),
HostScanner: boolutils.BoolPointer(false),
Account: "aaaaaaaaaa",
}
jsonBytes, err := json.Marshal(body)
assert.NoError(t, err)
u := url.URL{
Scheme: "http",
Host: "bla",
Path: "bla",
}
request, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(jsonBytes))
assert.NoError(t, err)
scanID := "ccccccc"
req, err := getScanParamsFromRequest(request, scanID)
assert.NoError(t, err)
assert.Equal(t, scanID, req.scanID)
assert.False(t, req.scanQueryParams.KeepResults)
assert.False(t, req.scanQueryParams.ReturnResults)
assert.False(t, *req.scanRequest.HostScanner)
assert.False(t, *req.scanRequest.Submit)
assert.Equal(t, "aaaaaaaaaa", req.scanRequest.Account)
}
}

View File

@@ -1,11 +1,8 @@
package v1
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sync"
utilsapisv1 "github.com/armosec/opa-utils/httpserver/apis/v1"
utilsmetav1 "github.com/armosec/opa-utils/httpserver/meta/v1"
@@ -19,29 +16,21 @@ import (
var OutputDir = "./results"
var FailedOutputDir = "./failed"
type ScanQueryParams struct {
ReturnResults bool `schema:"wait"` // wait for scanning to complete (synchronized request)
KeepResults bool `schema:"keep"` // do not delete results after returning (relevant only for synchronized requests)
}
type ResultsQueryParams struct {
ScanID string `schema:"id"`
KeepResults bool `schema:"keep"` // do not delete results after returning (default will delete results)
AllResults bool `schema:"all"` // delete all results
}
type StatusQueryParams struct {
ScanID string `schema:"id"`
}
type HTTPHandler struct {
state *serverState
state *serverState
scanResponseChan *scanResponseChan
scanRequestChan chan *scanRequestParams
}
func NewHTTPHandler() *HTTPHandler {
return &HTTPHandler{
state: newServerState(),
handler := &HTTPHandler{
state: newServerState(),
scanRequestChan: make(chan *scanRequestParams),
scanResponseChan: newScanResponseChan(),
}
go handler.executeScan()
return handler
}
// ============================================== STATUS ========================================================
@@ -80,7 +69,7 @@ func (handler *HTTPHandler) Status(w http.ResponseWriter, r *http.Request) {
}
// ============================================== SCAN ========================================================
// Scan API - TODO: break down to functions
// Scan API
func (handler *HTTPHandler) Scan(w http.ResponseWriter, r *http.Request) {
// generate id
@@ -88,86 +77,52 @@ func (handler *HTTPHandler) Scan(w http.ResponseWriter, r *http.Request) {
defer handler.recover(w, scanID)
defer r.Body.Close()
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
response := utilsmetav1.Response{}
w.Header().Set("Content-Type", "application/json")
scanQueryParams := &ScanQueryParams{}
if err := schema.NewDecoder().Decode(scanQueryParams, r.URL.Query()); err != nil {
handler.writeError(w, fmt.Errorf("failed to parse query params, reason: %s", err.Error()), scanID)
scanRequestParams, err := getScanParamsFromRequest(r, scanID)
if err != nil {
handler.writeError(w, err, "")
return
}
handler.state.setBusy(scanID)
// Add to queue
response := &utilsmetav1.Response{}
response.ID = scanID
response.Type = utilsapisv1.IDScanResponseType
response.Type = utilsapisv1.BusyScanResponseType
response.Response = fmt.Sprintf("scanning '%s' is in progress", scanID)
readBuffer, err := ioutil.ReadAll(r.Body)
if err != nil {
handler.writeError(w, fmt.Errorf("failed to read request body, reason: %s", err.Error()), scanID)
return
}
handler.scanResponseChan.set(scanID) // add channel
defer handler.scanResponseChan.delete(scanID)
logger.L().Info("REST API received scan request", helpers.String("body", string(readBuffer)))
scanRequest := utilsmetav1.PostScanRequest{}
if err := json.Unmarshal(readBuffer, &scanRequest); err != nil {
handler.writeError(w, fmt.Errorf("failed to parse request payload, reason: %s", err.Error()), scanID)
return
}
var wg sync.WaitGroup
if scanQueryParams.ReturnResults {
wg.Add(1)
} else {
wg.Add(0)
}
statusCode := http.StatusOK
// you must use a goroutine since the executeScan function is not always listening to the channel
go func() {
// execute scan in the background
// send to scanning handler
handler.scanRequestChan <- scanRequestParams
}()
logger.L().Info("scan triggered", helpers.String("ID", scanID))
if scanRequestParams.scanQueryParams.ReturnResults {
// wait for scan to complete
response = <-handler.scanResponseChan.get(scanID)
results, err := scan(&scanRequest, scanID)
if err != nil {
logger.L().Error("scanning failed", helpers.String("ID", scanID), helpers.Error(err))
if scanQueryParams.ReturnResults {
response.Type = utilsapisv1.ErrorScanResponseType
response.Response = err.Error()
statusCode = http.StatusInternalServerError
}
} else {
logger.L().Success("done scanning", helpers.String("ID", scanID))
if scanQueryParams.ReturnResults {
response.Type = utilsapisv1.ResultsV1ScanResponseType
response.Response = results
wg.Done()
}
}
if scanQueryParams.ReturnResults && !scanQueryParams.KeepResults {
if scanRequestParams.scanQueryParams.KeepResults {
// delete results after returning
logger.L().Debug("deleting results", helpers.String("ID", scanID))
removeResultsFile(scanID)
}
handler.state.setNotBusy(scanID)
}()
}
wg.Wait()
statusCode := http.StatusOK
if response.Type == utilsapisv1.ErrorScanResponseType {
statusCode = http.StatusInternalServerError
}
w.WriteHeader(statusCode)
w.Write(responseToBytes(&response))
}
func (handler *HTTPHandler) scan() {
for {
}
w.Write(responseToBytes(response))
}
// ============================================== RESULTS ========================================================
@@ -204,6 +159,7 @@ func (handler *HTTPHandler) Results(w http.ResponseWriter, r *http.Request) {
if handler.state.isBusy(resultsQueryParams.ScanID) { // if requested ID is still scanning
logger.L().Info("scan in process", helpers.String("ID", resultsQueryParams.ScanID))
w.WriteHeader(http.StatusOK)
response.Type = utilsapisv1.BusyScanResponseType
response.Response = fmt.Sprintf("scanning '%s' in progress", resultsQueryParams.ScanID)
w.Write(responseToBytes(&response))
return
@@ -253,11 +209,6 @@ func (handler *HTTPHandler) Ready(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func responseToBytes(res *utilsmetav1.Response) []byte {
b, _ := json.Marshal(res)
return b
}
func (handler *HTTPHandler) recover(w http.ResponseWriter, scanID string) {
response := utilsmetav1.Response{}
if err := recover(); err != nil {

View File

@@ -0,0 +1,32 @@
package v1
// ============================================== STATUS ========================================================
// Status API
// func TestStatus(t *testing.T) {
// {
// httpHandler := NewHTTPHandler()
// u := url.URL{
// Scheme: "http",
// Host: "bla",
// Path: "bla",
// RawQuery: "wait=true&keep=true",
// }
// request, err := http.NewRequest(http.MethodPost, u.String(), nil)
// httpHandler.Status(nil, request)
// assert.NoError(t, err)
// scanID := "ccccccc"
// req, err := getScanParamsFromRequest(request, scanID)
// assert.NoError(t, err)
// assert.Equal(t, scanID, req.scanID)
// assert.True(t, req.scanQueryParams.KeepResults)
// assert.True(t, req.scanQueryParams.ReturnResults)
// assert.True(t, *req.scanRequest.HostScanner)
// assert.True(t, *req.scanRequest.Submit)
// assert.Equal(t, "aaaaaaaaaa", req.scanRequest.Account)
// }
// }

View File

@@ -9,12 +9,45 @@ import (
"github.com/armosec/kubescape/v2/core/cautils"
"github.com/armosec/kubescape/v2/core/cautils/getter"
"github.com/armosec/kubescape/v2/core/cautils/logger"
"github.com/armosec/kubescape/v2/core/cautils/logger/helpers"
"github.com/armosec/kubescape/v2/core/core"
utilsapisv1 "github.com/armosec/opa-utils/httpserver/apis/v1"
utilsmetav1 "github.com/armosec/opa-utils/httpserver/meta/v1"
reporthandlingv2 "github.com/armosec/opa-utils/reporthandling/v2"
"github.com/armosec/utils-go/boolutils"
)
// executeScan execute the scan request passed in the channel
func (handler *HTTPHandler) executeScan() {
for {
scanReq := <-handler.scanRequestChan
response := &utilsmetav1.Response{}
logger.L().Info("scan triggered", helpers.String("ID", scanReq.scanID))
results, err := scan(scanReq.scanRequest, scanReq.scanID)
if err != nil {
logger.L().Error("scanning failed", helpers.String("ID", scanReq.scanID), helpers.Error(err))
if scanReq.scanQueryParams.ReturnResults {
response.Type = utilsapisv1.ErrorScanResponseType
response.Response = err.Error()
}
} else {
logger.L().Success("done scanning", helpers.String("ID", scanReq.scanID))
if scanReq.scanQueryParams.ReturnResults {
response.Type = utilsapisv1.ResultsV1ScanResponseType
response.Response = results
}
}
handler.state.setNotBusy(scanReq.scanID)
// return results
handler.scanResponseChan.push(scanReq.scanID, response)
}
}
func scan(scanRequest *utilsmetav1.PostScanRequest, scanID string) (*reporthandlingv2.PostureReport, error) {
scanInfo := getScanCommand(scanRequest, scanID)
@@ -136,7 +169,7 @@ func writeScanErrorToFile(err error, scanID string) error {
if e := os.MkdirAll(filepath.Dir(FailedOutputDir), os.ModePerm); e != nil {
return fmt.Errorf("failed to scan. reason: '%s'. failed to save error in file - failed to create directory. reason: %s", err.Error(), e.Error())
}
f, e := os.Create(filepath.Join(FailedOutputDir, scanID))
f, e := os.Create(filepath.Join(filepath.Dir(FailedOutputDir), scanID))
if e != nil {
return fmt.Errorf("failed to scan. reason: '%s'. failed to save error in file - failed to open file for writing. reason: %s", err.Error(), e.Error())
}
@@ -147,3 +180,9 @@ func writeScanErrorToFile(err error, scanID string) error {
}
return fmt.Errorf("failed to scan. reason: '%s'", err.Error())
}
// responseToBytes convert response object to bytes
func responseToBytes(res *utilsmetav1.Response) []byte {
b, _ := json.Marshal(res)
return b
}