diff --git a/pkg/auth/login.go b/pkg/auth/login.go index fa23917e..19e6625d 100644 --- a/pkg/auth/login.go +++ b/pkg/auth/login.go @@ -23,18 +23,20 @@ const ( pasteAnAuthenticationToken = "Paste an authentication token" ) -type tokenOrError struct { - token string - user string - err error +type paramsOrError struct { + token string + user string + organization string + err error } type Host struct { - Token string `yaml:"token"` - User string `yaml:"user"` + Token string `yaml:"token"` + User string `yaml:"user"` + Organization string `yaml:"organization"` } -var tokenOrErrorChan = make(chan tokenOrError) +var paramsOrErrorChan = make(chan paramsOrError) func HandleLogin() error { if _, err := os.Stat(polarisHostsFilepath); err == nil { @@ -69,13 +71,14 @@ func HandleLogin() error { return fmt.Errorf("asking how to authenticate: %w", err) } - var user, token string + var user, token, organization string if answer == loginUsingBrowser { listener, err := net.Listen("tcp", ":0") if err != nil { panic(err) } - err = openBrowser(fmt.Sprintf(insightsURL + registerPath + "?source=polaris&callbackUrl=" + fmt.Sprintf("http://localhost:%d/auth/login/callback", listener.Addr().(*net.TCPAddr).Port))) + localServerPort := listener.Addr().(*net.TCPAddr).Port + err = openBrowser(fmt.Sprintf(insightsURL + registerPath + "?source=polaris&callbackUrl=" + fmt.Sprintf("http://localhost:%d/auth/login/callback", localServerPort))) if err != nil { logrus.Fatal(err) } @@ -83,21 +86,22 @@ func HandleLogin() error { var router *mux.Router go func() { router = mux.NewRouter() - router.HandleFunc("/auth/login/callback", callbackHandler) + router.HandleFunc("/auth/login/callback", callbackHandler(localServerPort)) if err := http.Serve(listener, router); err != nil { - tokenOrErrorChan <- tokenOrError{err: fmt.Errorf("starting the local http server: %w", err)} + paramsOrErrorChan <- paramsOrError{err: fmt.Errorf("starting the local http server: %w", err)} } }() // wait the browser to callback the local server - tokenOrError := <-tokenOrErrorChan + paramOrError := <-paramsOrErrorChan - if tokenOrError.err != nil { - return tokenOrError.err + if paramOrError.err != nil { + return paramOrError.err } - token = tokenOrError.token - user = tokenOrError.user + token = paramOrError.token + user = paramOrError.user + organization = paramOrError.organization } else { var answer string err := survey.AskOne(&survey.Password{Message: "Paste your authentication token:"}, &answer, survey.WithValidator(validateToken)) @@ -105,7 +109,8 @@ func HandleLogin() error { return fmt.Errorf("asking how to authenticate: %w", err) } token = answer - user = "admin" // TODO: fetch name from bots endpoint + user = "admin" // TODO: Vitor - fetch name from bots endpoint + organization = "acme-co" // TODO: Vitor - fetch organization from bots endpoint } polarisCfgDir := filepath.Join(userHomeDir, ".config", "polaris") @@ -124,7 +129,7 @@ func HandleLogin() error { } }() - content := map[string]Host{insightsURL: {Token: token, User: user}} + content := map[string]Host{insightsURL: {Token: token, User: user, Organization: organization}} b, err := yaml.Marshal(content) if err != nil { return fmt.Errorf("marshalling yaml data: %w", err) @@ -138,24 +143,34 @@ func HandleLogin() error { logrus.Debugf("hosts file has been saved") fmt.Println("✓ Authentication complete.") - fmt.Printf("✓ Logged in as %s.\n", user) + fmt.Printf("✓ Logged in organization %s as %s.\n", organization, user) return nil } -func callbackHandler(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") - if len(token) == 0 { - tokenOrErrorChan <- tokenOrError{err: errors.New("token query param is required in callback")} +func callbackHandler(localServerPort int) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + var err error + if len(token) == 0 { + err = errors.New("token query param is required in callback") + } + user := r.URL.Query().Get("user") + if len(user) == 0 { + err = errors.New("user query param is required in callback") + } + organization := r.URL.Query().Get("organization") + if len(organization) == 0 { + err = errors.New("organization query param is required in callback") + } + if err != nil { + fmt.Fprintf(w, "unable to perform integration: %v", err) + paramsOrErrorChan <- paramsOrError{err: err} + return + } + fmt.Fprint(w, "integration finished successfully, you can safely close this tab now") + paramsOrErrorChan <- paramsOrError{token: token, user: user, organization: organization} return } - user := r.URL.Query().Get("user") - if len(user) == 0 { - tokenOrErrorChan <- tokenOrError{err: errors.New("user query param is required in callback")} - return - } - - tokenOrErrorChan <- tokenOrError{token: token, user: user} - return } func validateToken(args any) error {