From 823843f3846848172a3c96b4323574219fab6927 Mon Sep 17 00:00:00 2001 From: 6543 <6543@obermui.de> Date: Fri, 3 Apr 2026 10:50:43 +0200 Subject: [PATCH] Sanitize agent introduced pipeline/workflow/step state changes and log streaming (#6308) --- .golangci.yaml | 2 + server/model/agent.go | 3 +- server/rpc/auth_server.go | 2 +- server/rpc/auth_server_test.go | 285 +++++++++ server/rpc/authorizer.go | 2 +- server/rpc/authorizer_test.go | 336 ++++++++++ server/rpc/errors.go | 30 + server/rpc/filter.go | 2 +- server/rpc/filter_test.go | 2 +- server/rpc/jwt_manager.go | 2 +- server/rpc/jwt_manager_test.go | 279 ++++++++ server/rpc/rpc.go | 77 +-- server/rpc/rpc_integration_test.go | 994 +++++++++++++++++++++++++++++ server/rpc/rpc_test.go | 2 +- server/rpc/sanitize.go | 144 +++++ server/rpc/sanitize_test.go | 433 +++++++++++++ server/rpc/server.go | 2 +- 17 files changed, 2547 insertions(+), 50 deletions(-) create mode 100644 server/rpc/auth_server_test.go create mode 100644 server/rpc/authorizer_test.go create mode 100644 server/rpc/errors.go create mode 100644 server/rpc/jwt_manager_test.go create mode 100644 server/rpc/rpc_integration_test.go create mode 100644 server/rpc/sanitize.go create mode 100644 server/rpc/sanitize_test.go diff --git a/.golangci.yaml b/.golangci.yaml index d7ed36630..ce23bd90a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -226,6 +226,8 @@ linters: alias: store_mocks - pkg: go.woodpecker-ci.org/woodpecker/v3/server/services/config/mocks alias: config_service_mocks + - pkg: go.woodpecker-ci.org/woodpecker/v3/server/services/log/mocks + alias: log_mocks # kubernetes - pkg: k8s.io/api/core/v1 diff --git a/server/model/agent.go b/server/model/agent.go index 8390e349e..196456708 100644 --- a/server/model/agent.go +++ b/server/model/agent.go @@ -52,7 +52,8 @@ func (Agent) TableName() string { } func (a *Agent) IsSystemAgent() bool { - return a.OwnerID == IDNotSet + return a.OwnerID == IDNotSet && + a.OrgID == IDNotSet } func 
GenerateNewAgentToken() string { diff --git a/server/rpc/auth_server.go b/server/rpc/auth_server.go index aef148c2b..ebbbc4003 100644 --- a/server/rpc/auth_server.go +++ b/server/rpc/auth_server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package grpc +package rpc import ( "context" diff --git a/server/rpc/auth_server_test.go b/server/rpc/auth_server_test.go new file mode 100644 index 000000000..d73ee7552 --- /dev/null +++ b/server/rpc/auth_server_test.go @@ -0,0 +1,285 @@ +// Copyright 2026 Woodpecker Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rpc + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.woodpecker-ci.org/woodpecker/v3/rpc/proto" + "go.woodpecker-ci.org/woodpecker/v3/server/model" + store_mocks "go.woodpecker-ci.org/woodpecker/v3/server/store/mocks" + "go.woodpecker-ci.org/woodpecker/v3/server/store/types" +) + +// newAuthServer is a test helper that wires up a WoodpeckerAuthServer with the +// given master token and a mock store, then returns both so tests can set +// expectations before calling Auth / getAgent. 
+func newAuthServer(t *testing.T, masterToken string, store *store_mocks.MockStore) *WoodpeckerAuthServer { + t.Helper() + jwtManager := NewJWTManager("test-secret") + return NewWoodpeckerAuthServer(jwtManager, masterToken, store) +} + +func TestAuth(t *testing.T) { + t.Parallel() + + t.Run("master token with agentID=-1 creates new system agent and returns access token", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentCreate", &model.Agent{ + OwnerID: model.IDNotSet, + OrgID: model.IDNotSet, + Token: "master-secret", + Capacity: -1, + }).Return(nil).Once() + + srv := newAuthServer(t, "master-secret", store) + resp, err := srv.Auth(t.Context(), &proto.AuthRequest{ + AgentId: -1, + AgentToken: "master-secret", + }) + + require.NoError(t, err) + assert.Equal(t, "ok", resp.Status) + assert.NotEmpty(t, resp.AccessToken) + // The newly created agent has ID 0 (zero-value) because AgentCreate + // doesn't set it in the mock – verify the token at least round-trips. 
+ claims, verifyErr := NewJWTManager("test-secret").Verify(resp.AccessToken) + require.NoError(t, verifyErr) + assert.Equal(t, resp.AgentId, claims.AgentID) + }) + + t.Run("master token with existing agentID returns access token for that agent", func(t *testing.T) { + t.Parallel() + + existingAgent := &model.Agent{ + ID: 42, + OrgID: model.IDNotSet, // system agent + OwnerID: model.IDNotSet, + } + + store := store_mocks.NewMockStore(t) + store.On("AgentFind", int64(42)).Return(existingAgent, nil).Once() + + srv := newAuthServer(t, "master-secret", store) + resp, err := srv.Auth(t.Context(), &proto.AuthRequest{ + AgentId: 42, + AgentToken: "master-secret", + }) + + require.NoError(t, err) + assert.Equal(t, "ok", resp.Status) + assert.EqualValues(t, 42, resp.AgentId) + assert.NotEmpty(t, resp.AccessToken) + }) + + t.Run("individual agent token authenticates successfully", func(t *testing.T) { + t.Parallel() + + agent := &model.Agent{ID: 7, Token: "individual-token"} + + store := store_mocks.NewMockStore(t) + store.On("AgentFindByToken", "individual-token").Return(agent, nil).Once() + + // no master token configured + srv := newAuthServer(t, "", store) + resp, err := srv.Auth(t.Context(), &proto.AuthRequest{ + AgentId: 0, + AgentToken: "individual-token", + }) + + require.NoError(t, err) + assert.Equal(t, "ok", resp.Status) + assert.EqualValues(t, 7, resp.AgentId) + }) + + t.Run("bad token returns error", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentFindByToken", "wrong-token"). 
+ Return(nil, types.ErrRecordNotExist).Once() + + srv := newAuthServer(t, "", store) + _, err := srv.Auth(t.Context(), &proto.AuthRequest{ + AgentToken: "wrong-token", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "agent could not auth") + }) +} + +func TestGetAgent(t *testing.T) { + t.Parallel() + + t.Run("master token + agentID=-1 creates and returns a new system agent", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentCreate", &model.Agent{ + OwnerID: model.IDNotSet, + OrgID: model.IDNotSet, + Token: "master", + Capacity: -1, + }).Return(nil).Once() + + srv := newAuthServer(t, "master", store) + agent, err := srv.getAgent(-1, "master") + + require.NoError(t, err) + require.NotNil(t, agent) + assert.Equal(t, "master", agent.Token) + assert.EqualValues(t, model.IDNotSet, agent.OrgID) + }) + + t.Run("master token + agentID=-1 propagates AgentCreate error", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentCreate", &model.Agent{ + OwnerID: model.IDNotSet, + OrgID: model.IDNotSet, + Token: "master", + Capacity: -1, + }).Return(errors.New("db error")).Once() + + srv := newAuthServer(t, "master", store) + _, err := srv.getAgent(-1, "master") + + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("master token + existing agentID returns the stored agent", func(t *testing.T) { + t.Parallel() + + systemAgent := &model.Agent{ID: 99, OrgID: model.IDNotSet, OwnerID: model.IDNotSet} + + store := store_mocks.NewMockStore(t) + store.On("AgentFind", int64(99)).Return(systemAgent, nil).Once() + + srv := newAuthServer(t, "master", store) + agent, err := srv.getAgent(99, "master") + + require.NoError(t, err) + assert.Equal(t, int64(99), agent.ID) + }) + + t.Run("master token + agentID not found in database returns error", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentFind", int64(404)).Return(nil, 
types.ErrRecordNotExist).Once() + + srv := newAuthServer(t, "master", store) + _, err := srv.getAgent(404, "master") + + require.Error(t, err) + assert.Contains(t, err.Error(), "AgentID not found in database") + }) + + t.Run("master token + agentID store returns unexpected error is propagated", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentFind", int64(1)).Return(nil, errors.New("connection reset")).Once() + + srv := newAuthServer(t, "master", store) + _, err := srv.getAgent(1, "master") + + require.Error(t, err) + assert.Contains(t, err.Error(), "connection reset") + }) + + t.Run("master token + agentID that is not a system agent returns error", func(t *testing.T) { + t.Parallel() + + // An agent with a non-IDNotSet OrgID is not a system agent. + orgAgent := &model.Agent{ID: 5, OrgID: 100, OwnerID: model.IDNotSet} + + store := store_mocks.NewMockStore(t) + store.On("AgentFind", int64(5)).Return(orgAgent, nil).Once() + + srv := newAuthServer(t, "master", store) + _, err := srv.getAgent(5, "master") + + require.Error(t, err) + assert.Contains(t, err.Error(), "not a system agent") + }) + + t.Run("individual token auth succeeds when token is found", func(t *testing.T) { + t.Parallel() + + agent := &model.Agent{ID: 3, Token: "ind-token"} + store := store_mocks.NewMockStore(t) + store.On("AgentFindByToken", "ind-token").Return(agent, nil).Once() + + // No master token set – falls straight to individual auth. + srv := newAuthServer(t, "", store) + got, err := srv.getAgent(0, "ind-token") + + require.NoError(t, err) + assert.Equal(t, int64(3), got.ID) + }) + + t.Run("individual token not found returns wrapped error", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentFindByToken", "bad-token"). 
+ Return(nil, types.ErrRecordNotExist).Once() + + srv := newAuthServer(t, "", store) + _, err := srv.getAgent(0, "bad-token") + + require.Error(t, err) + assert.Contains(t, err.Error(), "individual agent not found by token") + }) + + t.Run("individual token store returns unexpected error is propagated", func(t *testing.T) { + t.Parallel() + + store := store_mocks.NewMockStore(t) + store.On("AgentFindByToken", "token"). + Return(nil, errors.New("timeout")).Once() + + srv := newAuthServer(t, "", store) + _, err := srv.getAgent(0, "token") + + require.Error(t, err) + assert.Contains(t, err.Error(), "timeout") + }) + + t.Run("master token configured but wrong token falls through to individual auth", func(t *testing.T) { + t.Parallel() + + agent := &model.Agent{ID: 8, Token: "ind-token"} + store := store_mocks.NewMockStore(t) + // master token is "master" but caller sends "ind-token" → individual path + store.On("AgentFindByToken", "ind-token").Return(agent, nil).Once() + + srv := newAuthServer(t, "master", store) + got, err := srv.getAgent(0, "ind-token") + + require.NoError(t, err) + assert.Equal(t, int64(8), got.ID) + }) +} diff --git a/server/rpc/authorizer.go b/server/rpc/authorizer.go index bdf21d432..6fb2c7021 100644 --- a/server/rpc/authorizer.go +++ b/server/rpc/authorizer.go @@ -44,7 +44,7 @@ // resp, _ := authClient.Auth(ctx, &proto.AuthRequest{AgentToken: "secret", AgentId: -1}) // ctx = metadata.AppendToOutgoingContext(ctx, "token", resp.AccessToken) // workflow, _ := woodpeckerClient.Next(ctx, &proto.NextRequest{...}) -package grpc +package rpc import ( "context" diff --git a/server/rpc/authorizer_test.go b/server/rpc/authorizer_test.go new file mode 100644 index 000000000..4be71e046 --- /dev/null +++ b/server/rpc/authorizer_test.go @@ -0,0 +1,336 @@ +// Copyright 2026 Woodpecker Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rpc + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func newAuthorizer(t *testing.T) *Authorizer { + t.Helper() + return NewAuthorizer(NewJWTManager("auth-test-secret")) +} + +// validTokenForAgent generates a JWT that the authorizer will accept. +func validTokenForAgent(t *testing.T, agentID int64) string { + t.Helper() + token, err := NewJWTManager("auth-test-secret").Generate(agentID) + require.NoError(t, err) + return token +} + +// ctxWithToken builds an incoming gRPC context carrying metadata["token"]. +func ctxWithToken(ctx context.Context, token string) context.Context { + return metadata.NewIncomingContext(ctx, metadata.Pairs("token", token)) +} + +func TestAuthorize(t *testing.T) { + t.Parallel() + + t.Run("Auth endpoint bypasses JWT validation", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + // Plain context with no metadata – would normally fail, but Auth is exempt. + ctx, err := a.authorize(t.Context(), "/proto.WoodpeckerAuth/Auth") + + require.NoError(t, err) + assert.NotNil(t, ctx) + }) + + t.Run("missing metadata returns Unauthenticated", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + // A plain context has no gRPC incoming metadata. 
+ _, err := a.authorize(t.Context(), "/proto.WoodpeckerServer/Next") + + require.Error(t, err) + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unauthenticated, s.Code()) + assert.Contains(t, s.Message(), "metadata is not provided") + }) + + t.Run("metadata present but token key absent returns Unauthenticated", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("other-key", "value")) + + _, err := a.authorize(ctx, "/proto.WoodpeckerServer/Next") + + require.Error(t, err) + s, _ := status.FromError(err) + assert.Equal(t, codes.Unauthenticated, s.Code()) + assert.Contains(t, s.Message(), "token is not provided") + }) + + t.Run("invalid (garbage) token returns Unauthenticated", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + ctx := ctxWithToken(t.Context(), "this-is-not-a-jwt") + + _, err := a.authorize(ctx, "/proto.WoodpeckerServer/Next") + + require.Error(t, err) + s, _ := status.FromError(err) + assert.Equal(t, codes.Unauthenticated, s.Code()) + assert.Contains(t, s.Message(), "access token is invalid") + }) + + t.Run("token signed with wrong secret returns Unauthenticated", func(t *testing.T) { + t.Parallel() + + wrongManager := NewJWTManager("DIFFERENT-secret") + token, err := wrongManager.Generate(55) + require.NoError(t, err) + + a := newAuthorizer(t) // uses "auth-test-secret" + ctx := ctxWithToken(t.Context(), token) + + _, err = a.authorize(ctx, "/proto.WoodpeckerServer/Next") + + require.Error(t, err) + s, _ := status.FromError(err) + assert.Equal(t, codes.Unauthenticated, s.Code()) + }) + + t.Run("valid token enriches context with agent_id metadata", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + token := validTokenForAgent(t, 77) + ctx := ctxWithToken(t.Context(), token) + + newCtx, err := a.authorize(ctx, "/proto.WoodpeckerServer/Next") + + require.NoError(t, err) + + md, ok := metadata.FromIncomingContext(newCtx) + 
require.True(t, ok) + agentIDs := md["agent_id"] + require.Len(t, agentIDs, 1) + assert.Equal(t, "77", agentIDs[0]) + }) + + t.Run("valid token preserves existing metadata keys", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + token := validTokenForAgent(t, 10) + ctx := metadata.NewIncomingContext(t.Context(), + metadata.Pairs("token", token, "hostname", "worker-1"), + ) + + newCtx, err := a.authorize(ctx, "/proto.WoodpeckerServer/Init") + + require.NoError(t, err) + md, _ := metadata.FromIncomingContext(newCtx) + assert.Equal(t, []string{"worker-1"}, md["hostname"]) + assert.Equal(t, []string{"10"}, md["agent_id"]) + }) + + t.Run("empty token value in metadata slice returns Unauthenticated", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + // Passing an empty string as the token value. + ctx := ctxWithToken(t.Context(), "") + + _, err := a.authorize(ctx, "/proto.WoodpeckerServer/Next") + + require.Error(t, err) + s, _ := status.FromError(err) + assert.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUnaryInterceptor(t *testing.T) { + t.Parallel() + + t.Run("valid token calls handler with enriched context", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + token := validTokenForAgent(t, 21) + ctx := ctxWithToken(t.Context(), token) + + var capturedCtx context.Context + handler := func(ctx context.Context, _ any) (any, error) { + capturedCtx = ctx + return "ok", nil + } + + resp, err := a.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/proto.WoodpeckerServer/Next", + }, handler) + + require.NoError(t, err) + assert.Equal(t, "ok", resp) + + md, ok := metadata.FromIncomingContext(capturedCtx) + require.True(t, ok) + assert.Equal(t, []string{"21"}, md["agent_id"]) + }) + + t.Run("invalid token does not call handler", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + ctx := ctxWithToken(t.Context(), "bad-token") + + handlerCalled := false + handler := func(_ context.Context, _ any) 
(any, error) { + handlerCalled = true + return nil, nil + } + + _, err := a.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/proto.WoodpeckerServer/Next", + }, handler) + + require.Error(t, err) + assert.False(t, handlerCalled) + }) + + t.Run("Auth endpoint bypasses token check and calls handler", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + // No token in context – fine because Auth is exempt. + ctx := metadata.NewIncomingContext(t.Context(), metadata.MD{}) + + handlerCalled := false + handler := func(_ context.Context, _ any) (any, error) { + handlerCalled = true + return nil, nil + } + + _, err := a.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/proto.WoodpeckerAuth/Auth", + }, handler) + + require.NoError(t, err) + assert.True(t, handlerCalled) + }) + + t.Run("handler error is propagated", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + token := validTokenForAgent(t, 1) + ctx := ctxWithToken(t.Context(), token) + + handler := func(_ context.Context, _ any) (any, error) { + return nil, errors.New("handler boom") + } + + _, err := a.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/proto.WoodpeckerServer/Next", + }, handler) + + require.Error(t, err) + assert.Contains(t, err.Error(), "handler boom") + }) +} + +// mockServerStream is a minimal grpc.ServerStream for testing. 
+type mockServerStream struct { + ctx context.Context +} + +func (m *mockServerStream) SetHeader(metadata.MD) error { return nil } +func (m *mockServerStream) SendHeader(metadata.MD) error { return nil } +func (m *mockServerStream) SetTrailer(metadata.MD) {} +func (m *mockServerStream) Context() context.Context { return m.ctx } +func (m *mockServerStream) SendMsg(any) error { return nil } +func (m *mockServerStream) RecvMsg(any) error { return nil } + +func TestStreamInterceptor(t *testing.T) { + t.Parallel() + + t.Run("valid token calls handler with enriched stream context", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + token := validTokenForAgent(t, 33) + ctx := ctxWithToken(t.Context(), token) + stream := &mockServerStream{ctx: ctx} + + var capturedStream grpc.ServerStream + handler := func(_ any, s grpc.ServerStream) error { + capturedStream = s + return nil + } + + err := a.StreamInterceptor(nil, stream, &grpc.StreamServerInfo{ + FullMethod: "/proto.WoodpeckerServer/Next", + }, handler) + + require.NoError(t, err) + + md, ok := metadata.FromIncomingContext(capturedStream.Context()) + require.True(t, ok) + assert.Equal(t, []string{"33"}, md["agent_id"]) + }) + + t.Run("invalid token does not call handler", func(t *testing.T) { + t.Parallel() + + a := newAuthorizer(t) + ctx := ctxWithToken(t.Context(), "garbage") + stream := &mockServerStream{ctx: ctx} + + handlerCalled := false + handler := func(_ any, _ grpc.ServerStream) error { + handlerCalled = true + return nil + } + + err := a.StreamInterceptor(nil, stream, &grpc.StreamServerInfo{ + FullMethod: "/proto.WoodpeckerServer/Next", + }, handler) + + require.Error(t, err) + assert.False(t, handlerCalled) + s, _ := status.FromError(err) + assert.Equal(t, codes.Unauthenticated, s.Code()) + }) + + t.Run("stream context wrapper SetContext and Context round-trip", func(t *testing.T) { + t.Parallel() + + stream := &mockServerStream{ctx: t.Context()} + wrapper := newStreamContextWrapper(stream) + + 
newCtx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("foo", "bar")) + wrapper.SetContext(newCtx) + + md, ok := metadata.FromIncomingContext(wrapper.Context()) + require.True(t, ok) + assert.Equal(t, []string{"bar"}, md["foo"]) + }) +} diff --git a/server/rpc/errors.go b/server/rpc/errors.go new file mode 100644 index 000000000..2065904a7 --- /dev/null +++ b/server/rpc/errors.go @@ -0,0 +1,30 @@ +// Copyright 2026 Woodpecker Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package rpc + +import "errors" + +var ( + ErrAgentIllegalPipelineWorkflowReRunStateChange = errors.New("workflow has parent pipeline marked as finished") + ErrAgentIllegalPipelineWorkflowRun = errors.New("workflow has parent pipeline in blocked state") + + ErrAgentIllegalWorkflowReRunStateChange = errors.New("workflow was already marked as finished") + ErrAgentIllegalWorkflowRun = errors.New("workflow is currently in blocked state") + + ErrAgentIllegalStepReRunStateChange = errors.New("step was already marked as finished") + ErrAgentIllegalStepRun = errors.New("step is currently in blocked state") + + ErrAgentIllegalLogStreaming = errors.New("agent can not append logs to a step that is marked not running") +) diff --git a/server/rpc/filter.go b/server/rpc/filter.go index 5d75cd71d..db1c5a2e8 100644 --- a/server/rpc/filter.go +++ b/server/rpc/filter.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package grpc +package rpc import ( "maps" diff --git a/server/rpc/filter_test.go b/server/rpc/filter_test.go index 04d0e13dc..fb9b8f4ab 100644 --- a/server/rpc/filter_test.go +++ b/server/rpc/filter_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package grpc +package rpc import ( "testing" diff --git a/server/rpc/jwt_manager.go b/server/rpc/jwt_manager.go index 8cf7cbc36..88a7c20f9 100644 --- a/server/rpc/jwt_manager.go +++ b/server/rpc/jwt_manager.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package grpc +package rpc import ( "errors" diff --git a/server/rpc/jwt_manager_test.go b/server/rpc/jwt_manager_test.go new file mode 100644 index 000000000..2de1e993b --- /dev/null +++ b/server/rpc/jwt_manager_test.go @@ -0,0 +1,279 @@ +// Copyright 2026 Woodpecker Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rpc + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJWTManager(t *testing.T) { + t.Parallel() + + t.Run("generate and verify roundtrip", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + token, err := manager.Generate(42) + require.NoError(t, err) + assert.NotEmpty(t, token) + + claims, err := manager.Verify(token) + require.NoError(t, err) + assert.Equal(t, int64(42), claims.AgentID) + }) + + t.Run("claims contain correct fields", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + token, err := manager.Generate(99) + require.NoError(t, err) + + claims, err := manager.Verify(token) + require.NoError(t, err) + + assert.Equal(t, int64(99), claims.AgentID) + assert.Equal(t, "woodpecker", claims.Issuer) + assert.Equal(t, fmt.Sprintf("%d", 99), claims.Subject) + assert.Equal(t, fmt.Sprintf("%d", 99), claims.ID) + }) + + t.Run("different agent IDs produce different tokens", func(t *testing.T) { + t.Parallel() + + manager := 
NewJWTManager("test-secret") + token1, err := manager.Generate(1) + require.NoError(t, err) + + token2, err := manager.Generate(2) + require.NoError(t, err) + + assert.NotEqual(t, token1, token2) + }) + + t.Run("expired token is rejected", func(t *testing.T) { + t.Parallel() + + manager := &JWTManager{ + secretKey: "test-secret", + tokenDuration: 1 * time.Millisecond, + } + + token, err := manager.Generate(42) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + _, err = manager.Verify(token) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid token") + }) + + t.Run("wrong signing secret rejects token", func(t *testing.T) { + t.Parallel() + + managerA := NewJWTManager("secret-A") + managerB := NewJWTManager("secret-B") + + token, err := managerA.Generate(42) + require.NoError(t, err) + + _, err = managerB.Verify(token) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid token") + }) + + t.Run("tampered token is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + token, err := manager.Generate(42) + require.NoError(t, err) + + // Flip the first signature character rather than overwriting the last one: + // the final base64url character can carry padding bits that non-strict + // decoding ignores, so a fixed "X" may decode to the very same signature + // (or equal the original character) and make this test flaky. + sigStart := strings.LastIndexByte(token, '.') + 1 + flipped := byte('A') + if token[sigStart] == flipped { + flipped = 'B' + } + tampered := token[:sigStart] + string(flipped) + token[sigStart+1:] + + _, err = manager.Verify(tampered) + assert.Error(t, err) + }) + + t.Run("empty token is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + _, err := manager.Verify("") + assert.Error(t, err) + }) + + t.Run("garbage token is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + _, err := manager.Verify("this-is-not-a-jwt") + assert.Error(t, err) + }) + + t.Run("token generated with negative agent ID", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + token, err := manager.Generate(-1) + require.NoError(t, err) + + claims, err := manager.Verify(token) + require.NoError(t, err) + assert.Equal(t, int64(-1), claims.AgentID) + }) +} + +// buildUnsignedToken manually
constructs a JWT with alg=none so we can verify +// that Verify() rejects it even though the signature section is empty. +// We do NOT use the golang-jwt library here because modern versions refuse to +// produce none-signed tokens — that is exactly the property we want to test. +func buildUnsignedToken(t *testing.T, agentID int64) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]string{"alg": "none", "typ": "JWT"}), + ) + payload := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]any{ + "agent_id": agentID, + "iss": "woodpecker", + }), + ) + // A none-signed JWT carries an empty signature segment. + return header + "." + payload + "." +} + +// buildRS256FakeToken constructs a JWT header claiming RS256 to exercise the +// unexpected-signing-method guard inside JWTManager.Verify(). +func buildRS256FakeToken(t *testing.T) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]string{"alg": "RS256", "typ": "JWT"}), + ) + payload := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]any{"agent_id": 1, "iss": "woodpecker"}), + ) + sig := base64.RawURLEncoding.EncodeToString([]byte("fake-rsa-sig")) + return header + "." + payload + "." + sig +} + +// buildFutureNbfToken constructs a JWT whose nbf claim is set far in the +// future. The token must be rejected regardless of which check fires first. +func buildFutureNbfToken(t *testing.T) string { + t.Helper() + const farFuture = int64(9_999_999_999) // year 2286 + header := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]string{"alg": "HS256", "typ": "JWT"}), + ) + payload := base64.RawURLEncoding.EncodeToString( + jwtMustMarshal(t, map[string]any{ + "agent_id": 1, + "iss": "woodpecker", + "nbf": farFuture, + "exp": farFuture + 3600, + }), + ) + badSig := base64.RawURLEncoding.EncodeToString([]byte("bad")) + return header + "." + payload + "." 
+ badSig +} + +func jwtMustMarshal(t *testing.T, v any) []byte { + t.Helper() + b, err := json.Marshal(v) + require.NoError(t, err) + return b +} + +func TestJWTManagerAdditional(t *testing.T) { + t.Parallel() + + t.Run("none-algorithm token is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + noneToken := buildUnsignedToken(t, 42) + + // Sanity: token really does carry the none algorithm header. + parts := strings.Split(noneToken, ".") + require.Len(t, parts, 3) + assert.Equal(t, "", parts[2], "signature part must be empty for none-alg tokens") + + _, err := manager.Verify(noneToken) + require.Error(t, err, "verifier must reject a none-algorithm token") + assert.Contains(t, err.Error(), "invalid token") + }) + + t.Run("RS256 token (unexpected signing method) is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + rs256Token := buildRS256FakeToken(t) + + _, err := manager.Verify(rs256Token) + require.Error(t, err, "verifier must reject tokens with an unexpected signing method") + assert.Contains(t, err.Error(), "invalid token") + }) + + t.Run("token with far-future NotBefore is rejected", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + futureToken := buildFutureNbfToken(t) + + _, err := manager.Verify(futureToken) + assert.Error(t, err) + }) + + t.Run("two valid tokens for same agent are each independently verifiable", func(t *testing.T) { + t.Parallel() + + manager := NewJWTManager("test-secret") + + tok1, err := manager.Generate(5) + require.NoError(t, err) + tok2, err := manager.Generate(5) + require.NoError(t, err) + + claims1, err := manager.Verify(tok1) + require.NoError(t, err) + assert.Equal(t, int64(5), claims1.AgentID) + + claims2, err := manager.Verify(tok2) + require.NoError(t, err) + assert.Equal(t, int64(5), claims2.AgentID) + }) + + t.Run("zero agent ID is preserved through generate/verify roundtrip", func(t *testing.T) { + t.Parallel()
+ + manager := NewJWTManager("test-secret") + token, err := manager.Generate(0) + require.NoError(t, err) + + claims, err := manager.Verify(token) + require.NoError(t, err) + assert.Equal(t, int64(0), claims.AgentID) + }) +} diff --git a/server/rpc/rpc.go b/server/rpc/rpc.go index 2ed19cf87..0bf0e2c63 100644 --- a/server/rpc/rpc.go +++ b/server/rpc/rpc.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package grpc +package rpc import ( "context" @@ -199,6 +199,14 @@ func (s *RPC) Update(c context.Context, strWorkflowID string, state rpc.StepStat return err } + // sanitize agent input + if err := checkPipelineState(currentPipeline); err != nil { + return err + } + if err := checkWorkflowStepStates(workflow, step); err != nil { + return err + } + if err := pipeline.UpdateStepStatus(c, s.store, step, state); err != nil { log.Error().Err(err).Msg("rpc.update: cannot update step") } @@ -252,6 +260,14 @@ func (s *RPC) Init(c context.Context, strWorkflowID string, state rpc.WorkflowSt return err } + // sanitize agent input + if err := checkPipelineState(currentPipeline); err != nil { + return err + } + if err := checkWorkflowStepStates(workflow, nil); err != nil { + return err + } + if currentPipeline.Status == model.StatusPending { if currentPipeline, err = pipeline.UpdateToStatusRunning(s.store, *currentPipeline, state.Started); err != nil { log.Error().Err(err).Msgf("init: cannot update pipeline %d state", currentPipeline.ID) @@ -317,6 +333,14 @@ func (s *RPC) Done(c context.Context, strWorkflowID string, state rpc.WorkflowSt return err } + // sanitize agent input + if err := checkPipelineState(currentPipeline); err != nil { + return err + } + if err := checkWorkflowStepStates(workflow, nil); err != nil { + return err + } + logger := log.With(). Str("repo_id", fmt.Sprint(repo.ID)). Str("pipeline_id", fmt.Sprint(currentPipeline.ID)). 
@@ -367,8 +391,10 @@ func (s *RPC) Done(c context.Context, strWorkflowID string, state rpc.WorkflowSt // make sure writes to pubsub are non blocking (https://github.com/woodpecker-ci/woodpecker/blob/c919f32e0b6432a95e1a6d3d0ad662f591adf73f/server/logging/log.go#L9) go func() { for _, step := range workflow.Children { - if err := s.logger.Close(c, step.ID); err != nil { - logger.Error().Err(err).Msgf("done: cannot close log stream for step %d", step.ID) + if step.State != model.StatusSkipped { + if err := s.logger.Close(c, step.ID); err != nil { + logger.Error().Err(err).Msgf("done: cannot close log stream for step %d", step.ID) + } } } }() @@ -412,6 +438,11 @@ func (s *RPC) Log(c context.Context, stepUUID string, rpcLogEntries []*rpc.LogEn return err } + // sanitize agent input + if err := allowAppendingLogs(currentPipeline, step); err != nil { + return fmt.Errorf("can not alter logs: %w", err) + } + err = s.updateAgentLastWork(agent) if err != nil { return err @@ -506,48 +537,10 @@ func (s *RPC) ReportHealth(ctx context.Context, status string) error { return s.store.AgentUpdate(agent) } -func (s *RPC) checkAgentPermissionByWorkflow(_ context.Context, agent *model.Agent, strWorkflowID string, pipeline *model.Pipeline, repo *model.Repo) error { - var err error - if repo == nil && pipeline == nil { - workflowID, err := strconv.ParseInt(strWorkflowID, 10, 64) - if err != nil { - return err - } - - workflow, err := s.store.WorkflowLoad(workflowID) - if err != nil { - log.Error().Err(err).Msgf("cannot find workflow with id %d", workflowID) - return err - } - - pipeline, err = s.store.GetPipeline(workflow.PipelineID) - if err != nil { - log.Error().Err(err).Msgf("cannot find pipeline with id %d", workflow.PipelineID) - return err - } - } - - if repo == nil { - repo, err = s.store.GetRepo(pipeline.RepoID) - if err != nil { - log.Error().Err(err).Msgf("cannot find repo with id %d", pipeline.RepoID) - return err - } - } - - if agent.CanAccessRepo(repo) { - return nil - } - 
- msg := fmt.Sprintf("agent '%d' is not allowed to interact with repo[%d] '%s'", agent.ID, repo.ID, repo.FullName) - log.Error().Int64("repoId", repo.ID).Msg(msg) - return errors.New(msg) -} - func (s *RPC) completeChildrenIfParentCompleted(completedWorkflow *model.Workflow, finished int64) { for _, c := range completedWorkflow.Children { if c.Running() { - if updated, err := pipeline.UpdateStepToStatusSkipped(s.store, *c, finished, model.StatusSkipped); err != nil { + if updated, err := pipeline.UpdateStepToStatusSkipped(s.store, *c, finished, model.StatusKilled); err != nil { log.Error().Err(err).Msgf("done: cannot update step_id %d child state", c.ID) } else { // Update in-memory state so WorkflowStatus sees the final state diff --git a/server/rpc/rpc_integration_test.go b/server/rpc/rpc_integration_test.go new file mode 100644 index 000000000..f05204eca --- /dev/null +++ b/server/rpc/rpc_integration_test.go @@ -0,0 +1,994 @@ +// Copyright 2026 Woodpecker Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package rpc
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc/metadata"
+
+	"go.woodpecker-ci.org/woodpecker/v3/rpc"
+	"go.woodpecker-ci.org/woodpecker/v3/server"
+	"go.woodpecker-ci.org/woodpecker/v3/server/logging"
+	"go.woodpecker-ci.org/woodpecker/v3/server/model"
+	"go.woodpecker-ci.org/woodpecker/v3/server/pubsub/memory"
+	queue_mocks "go.woodpecker-ci.org/woodpecker/v3/server/queue/mocks"
+	log_mocks "go.woodpecker-ci.org/woodpecker/v3/server/services/log/mocks"
+	store_mocks "go.woodpecker-ci.org/woodpecker/v3/server/store/mocks"
+)
+
+// newTestRPC returns an RPC wired to mockStore, memory.New() pubsub, logging.New() and per-test metric vectors; queue is left unset.
+func newTestRPC(t *testing.T, mockStore *store_mocks.MockStore) RPC {
+	t.Helper()
+
+	pipelineTime := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+		Namespace: "woodpecker_test",
+		Name:      "pipeline_time_" + t.Name(), // test name suffix keeps the metric name distinct per test
+	}, []string{"repo", "branch", "status", "pipeline"})
+	pipelineCount := prometheus.NewCounterVec(prometheus.CounterOpts{
+		Namespace: "woodpecker_test",
+		Name:      "pipeline_count_" + t.Name(), // test name suffix keeps the metric name distinct per test
+	}, []string{"repo", "branch", "status", "pipeline"})
+
+	return RPC{
+		store:         mockStore,
+		pubsub:        memory.New(),
+		logger:        logging.New(),
+		pipelineTime:  pipelineTime,
+		pipelineCount: pipelineCount,
+	}
+}
+
+// defaultAgent returns an agent (ID 1) that is not scoped to any org (OrgID = model.IDNotSet); OwnerID keeps its zero value.
+func defaultAgent() *model.Agent {
+	return &model.Agent{
+		ID:    1,
+		Name:  "test-agent",
+		OrgID: model.IDNotSet, // NOTE(review): IsSystemAgent also requires OwnerID == IDNotSet, so this is not a "system agent" by that check; confirm whether OwnerID should be set too
+	}
+}
+
+// orgAgent999 returns an agent (ID 2) scoped to a single org, OrgID 999.
+func orgAgent999() *model.Agent { + return &model.Agent{ + ID: 2, + Name: "org-agent", + OrgID: 999, + } +} + +func defaultRepo() *model.Repo { + return &model.Repo{ + ID: 10, + OrgID: 100, + FullName: "test-org/test-repo", + } +} + +func defaultPipeline(status model.StatusValue) *model.Pipeline { + return &model.Pipeline{ + ID: 20, + RepoID: 10, + Status: status, + Branch: "main", + } +} + +func defaultWorkflow(state model.StatusValue) *model.Workflow { + return &model.Workflow{ + ID: 30, + PipelineID: 20, + State: state, + Name: "test-workflow", + } +} + +func defaultStep(state model.StatusValue) *model.Step { + return &model.Step{ + ID: 40, + UUID: "step-uuid-123", + PipelineID: 20, + State: state, + } +} + +func TestRPCUpdate(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + repo := defaultRepo() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(repo, nil) + // pipeline.UpdateStepStatus calls StepUpdate + mockStore.On("StepUpdate", mock.Anything).Return(nil) + mockStore.On("WorkflowGetTree", mock.Anything).Return([]*model.Workflow{workflow}, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{ + StepUUID: "step-uuid-123", + Started: 100, + Exited: false, + }) + 
assert.NoError(t, err) + }) + + t.Run("reject pipeline already succeeded", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusSuccess) + workflow := defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowReRunStateChange) + }) + + t.Run("reject pipeline already failed", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusFailure) + workflow := defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowReRunStateChange) + }) + + t.Run("reject pipeline blocked", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusBlocked) + workflow := 
defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowRun) + }) + + t.Run("reject workflow already finished", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusSuccess) // finished + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange) + }) + + t.Run("reject step already finished", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusSuccess) // finished + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", 
int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.ErrorIs(t, err, ErrAgentIllegalStepReRunStateChange) + }) + + t.Run("reject step belongs to different pipeline", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + step := &model.Step{ + ID: 40, + UUID: "step-uuid-123", + PipelineID: 999, // different pipeline! + State: model.StatusRunning, + } + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not belong to current pipeline") + }) + + t.Run("reject agent from wrong org", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := orgAgent999() + repo := defaultRepo() // org 100 + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("GetRepo", int64(10)).Return(repo, nil) + + rpcInst := 
newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) + + t.Run("reject invalid workflow ID", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "not-a-number", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.Error(t, err) + }) + + t.Run("reject nonexistent workflow", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockStore.On("WorkflowLoad", int64(999)).Return(nil, errors.New("not found")) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "999", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.Error(t, err) + }) + + t.Run("reject nonexistent step UUID", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("StepByUUID", "nonexistent").Return(nil, errors.New("not found")) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "nonexistent"}) + assert.Error(t, err) + }) + + t.Run("reject missing agent metadata", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + + 
mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + + rpcInst := newTestRPC(t, mockStore) + // no agent_id in metadata + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs()) + + err := rpcInst.Update(ctx, "30", rpc.StepState{StepUUID: "step-uuid-123"}) + assert.Error(t, err) + }) +} + +func TestRPCInit(t *testing.T) { + t.Run("happy path - pending pipeline", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + repo := defaultRepo() + pipeline := defaultPipeline(model.StatusPending) + workflow := defaultWorkflow(model.StatusPending) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(repo, nil) + // pipeline.UpdateToStatusRunning -> UpdatePipeline + mockStore.On("UpdatePipeline", mock.Anything).Return(nil) + // updateForgeStatus -> GetUser returns error so forge interaction is skipped + mockStore.On("GetUser", mock.Anything).Return(nil, errors.New("user not found")) + // pipeline.UpdateWorkflowStatusToRunning -> WorkflowUpdate + mockStore.On("WorkflowUpdate", mock.Anything).Return(nil) + // pubsub deferred -> WorkflowGetTree + mockStore.On("WorkflowGetTree", mock.Anything).Return([]*model.Workflow{workflow}, nil) + // updateAgentLastWork -> AgentUpdate + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.NoError(t, err) + }) + + t.Run("happy path - already running pipeline", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + repo := defaultRepo() + pipeline := defaultPipeline(model.StatusRunning) // another workflow already 
started it + workflow := defaultWorkflow(model.StatusPending) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(repo, nil) + // updateForgeStatus -> GetUser returns error so forge interaction is skipped + mockStore.On("GetUser", mock.Anything).Return(nil, errors.New("user not found")) + mockStore.On("WorkflowUpdate", mock.Anything).Return(nil) + mockStore.On("WorkflowGetTree", mock.Anything).Return([]*model.Workflow{workflow}, nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.NoError(t, err) + }) + + t.Run("reject pipeline already succeeded", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusSuccess) + workflow := defaultWorkflow(model.StatusPending) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowReRunStateChange) + }) + + t.Run("reject pipeline blocked", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusBlocked) + workflow := defaultWorkflow(model.StatusPending) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, 
nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowRun) + }) + + t.Run("reject workflow already finished", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusSuccess) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange) + }) + + t.Run("reject workflow blocked", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusBlocked) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + assert.ErrorIs(t, err, ErrAgentIllegalWorkflowRun) + }) + + t.Run("reject agent wrong org", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := orgAgent999() + 
pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusPending) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + err := rpcInst.Init(ctx, "30", rpc.WorkflowState{Started: 100}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) + + t.Run("reject invalid workflow ID", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Init(ctx, "not-a-number", rpc.WorkflowState{}) + assert.Error(t, err) + }) +} + +func TestRPCDone(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockQueue := queue_mocks.NewMockQueue(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + repo := defaultRepo() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + workflow.Children = []*model.Step{} + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(repo, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("WorkflowUpdate", mock.Anything).Return(nil) + mockStore.On("WorkflowGetTree", mock.Anything).Return([]*model.Workflow{}, nil) + 
mockStore.On("UpdatePipeline", mock.Anything).Return(nil) + mockStore.On("GetUser", mock.Anything).Return(nil, errors.New("user not found")) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + mockQueue.On("Done", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + rpcInst.queue = mockQueue + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Started: 100, Finished: 200}) + assert.NoError(t, err) + }) + + t.Run("reject pipeline already succeeded", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusSuccess) + workflow := defaultWorkflow(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Finished: 200}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowReRunStateChange) + }) + + t.Run("reject pipeline killed", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusKilled) + workflow := defaultWorkflow(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx 
:= metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Finished: 200}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowReRunStateChange) + }) + + t.Run("reject pipeline blocked", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusBlocked) + workflow := defaultWorkflow(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Finished: 200}) + assert.ErrorIs(t, err, ErrAgentIllegalPipelineWorkflowRun) + }) + + t.Run("reject workflow already finished", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusSuccess) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Finished: 200}) + assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange) + }) + + t.Run("reject agent wrong org", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := 
orgAgent999() + pipeline := defaultPipeline(model.StatusRunning) + workflow := defaultWorkflow(model.StatusRunning) + + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("StepListFromWorkflowFind", mock.Anything).Return([]*model.Step{}, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + err := rpcInst.Done(ctx, "30", rpc.WorkflowState{Finished: 200}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) + + t.Run("reject invalid workflow ID", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Done(ctx, "invalid", rpc.WorkflowState{}) + assert.Error(t, err) + }) +} + +func TestRPCLog(t *testing.T) { + // helper: a pipeline whose Finished timestamp is far enough in the past + // that it is outside the drain window, so log appending is rejected. + stalePipeline := func(status model.StatusValue) *model.Pipeline { + p := defaultPipeline(status) + p.Finished = time.Now().Add(-(logStreamDelayAllowed + time.Minute)).Unix() + return p + } + + // helper: a pipeline that finished very recently (within drain window). 
+ recentPipeline := func(status model.StatusValue) *model.Pipeline { + p := defaultPipeline(status) + p.Finished = time.Now().Add(-30 * time.Second).Unix() + return p + } + + t.Run("happy path: step running, pipeline running", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + mockLogStore.On("LogAppend", mock.Anything, mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + entries := []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Line: 0, Data: []byte("hello")}, + {StepUUID: "step-uuid-123", Line: 1, Data: []byte("world")}, + } + err := rpcInst.Log(ctx, "step-uuid-123", entries) + assert.NoError(t, err) + }) + + t.Run("allow: step finished but pipeline still running (logs draining)", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) // pipeline still running + step := defaultStep(model.StatusSuccess) // but step already finished + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + 
mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + mockLogStore.On("LogAppend", mock.Anything, mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("late log")}, + }) + assert.NoError(t, err) + }) + + t.Run("allow: step running even though pipeline finished stale (step takes priority)", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + pipeline := stalePipeline(model.StatusSuccess) // finished long ago + step := defaultStep(model.StatusRunning) // but step is still running + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + mockLogStore.On("LogAppend", mock.Anything, mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("running log")}, + }) + assert.NoError(t, err) + }) + + t.Run("allow: pipeline finished recently — within drain window", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := 
server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + pipeline := recentPipeline(model.StatusSuccess) // finished 30s ago + step := defaultStep(model.StatusSuccess) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + mockLogStore.On("LogAppend", mock.Anything, mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("drain log")}, + }) + assert.NoError(t, err) + }) + + t.Run("reject: pipeline finished stale and step not running", func(t *testing.T) { + // This replaces the old "reject pipeline already finished" test. + // Previously the rejection came from checkPipelineState returning + // ErrAgentIllegalPipelineWorkflowReRunStateChange. + // Now it comes from allowAppendingLogs returning ErrAgentIllegalLogStreaming. 
+ mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := stalePipeline(model.StatusSuccess) + step := defaultStep(model.StatusSuccess) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "can not alter logs") + assert.ErrorIs(t, err, ErrAgentIllegalLogStreaming) + // The old error is no longer returned from Log() — allowAppendingLogs + // now handles the pipeline-finished case itself. + assert.False(t, errors.Is(err, ErrAgentIllegalPipelineWorkflowReRunStateChange)) + }) + + t.Run("reject: pipeline failed stale and step not running", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := stalePipeline(model.StatusFailure) + step := defaultStep(model.StatusFailure) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.ErrorIs(t, err, ErrAgentIllegalLogStreaming) + }) + + t.Run("reject: step pending (not running), pipeline not running, outside drain window", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := 
defaultAgent() + pipeline := stalePipeline(model.StatusKilled) + step := defaultStep(model.StatusPending) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "can not alter logs") + assert.ErrorIs(t, err, ErrAgentIllegalLogStreaming) + }) + + t.Run("reject: step already succeeded, pipeline succeeded stale", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := stalePipeline(model.StatusSuccess) + step := defaultStep(model.StatusSuccess) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.ErrorIs(t, err, ErrAgentIllegalLogStreaming) + }) + + t.Run("reject: step killed, pipeline killed stale", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := defaultAgent() + pipeline := stalePipeline(model.StatusKilled) + step := defaultStep(model.StatusKilled) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + 
mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.ErrorIs(t, err, ErrAgentIllegalLogStreaming) + }) + + t.Run("reject mismatched step UUID in log entry", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockLogStore := log_mocks.NewMockService(t) + origLogStore := server.Config.Services.LogStore + server.Config.Services.LogStore = mockLogStore + t.Cleanup(func() { server.Config.Services.LogStore = origLogStore }) + + agent := defaultAgent() + pipeline := defaultPipeline(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(1)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + mockStore.On("AgentUpdate", mock.Anything).Return(nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + // Second entry has a rogue UUID — agent trying to inject into another step. 
+ entries := []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Line: 0, Data: []byte("ok")}, + {StepUUID: "DIFFERENT-UUID", Line: 1, Data: []byte("injected!")}, + } + err := rpcInst.Log(ctx, "step-uuid-123", entries) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected step UUID") + }) + + t.Run("reject agent wrong org", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := orgAgent999() + pipeline := defaultPipeline(model.StatusRunning) + step := defaultStep(model.StatusRunning) + + mockStore.On("StepByUUID", "step-uuid-123").Return(step, nil) + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + err := rpcInst.Log(ctx, "step-uuid-123", []*rpc.LogEntry{ + {StepUUID: "step-uuid-123", Data: []byte("test")}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) + + t.Run("reject nonexistent step UUID", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + mockStore.On("StepByUUID", "nonexistent").Return(nil, errors.New("not found")) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "1")) + + err := rpcInst.Log(ctx, "nonexistent", []*rpc.LogEntry{ + {StepUUID: "nonexistent", Data: []byte("test")}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not find step") + }) +} + +func TestRPCExtend(t *testing.T) { + t.Run("reject agent wrong org via permission check", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := orgAgent999() + workflow := defaultWorkflow(model.StatusRunning) + pipeline := defaultPipeline(model.StatusRunning) + + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + mockStore.On("AgentUpdate", 
mock.Anything).Return(nil) + // checkAgentPermissionByWorkflow with nil pipeline/repo -> loads from store + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + err := rpcInst.Extend(ctx, "30") + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) +} + +func TestRPCWait(t *testing.T) { + t.Run("reject agent wrong org", func(t *testing.T) { + mockStore := store_mocks.NewMockStore(t) + agent := orgAgent999() + workflow := defaultWorkflow(model.StatusRunning) + pipeline := defaultPipeline(model.StatusRunning) + + mockStore.On("AgentFind", int64(2)).Return(agent, nil) + // checkAgentPermissionByWorkflow loads from store + mockStore.On("WorkflowLoad", int64(30)).Return(workflow, nil) + mockStore.On("GetPipeline", int64(20)).Return(pipeline, nil) + mockStore.On("GetRepo", int64(10)).Return(defaultRepo(), nil) + + rpcInst := newTestRPC(t, mockStore) + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("agent_id", "2")) + + _, err := rpcInst.Wait(ctx, "30") + require.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to interact") + }) +} diff --git a/server/rpc/rpc_test.go b/server/rpc/rpc_test.go index ef1eb4df7..08caad5c4 100644 --- a/server/rpc/rpc_test.go +++ b/server/rpc/rpc_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package grpc
+package rpc
 
 import (
 	"testing"
diff --git a/server/rpc/sanitize.go b/server/rpc/sanitize.go
new file mode 100644
index 000000000..84af183a8
--- /dev/null
+++ b/server/rpc/sanitize.go
@@ -0,0 +1,172 @@
+// Copyright 2026 Woodpecker Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpc
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/rs/zerolog/log"
+
+	"go.woodpecker-ci.org/woodpecker/v3/server/model"
+)
+
+// logStreamDelayAllowed is how long after a pipeline finished agents may
+// still deliver logs for its steps.
+const logStreamDelayAllowed = 5 * time.Minute
+
+// checkAgentPermissionByWorkflow verifies that agent may act on the repo a
+// workflow belongs to.
+//
+// Callers that already hold the pipeline and/or repo pass them in to save
+// store lookups. When both are nil, strWorkflowID is parsed and the workflow
+// and its pipeline are loaded from the store; when repo is nil, it is loaded
+// via pipeline.RepoID. The result is nil if agent.CanAccessRepo(repo) holds,
+// otherwise an error naming the agent and the repo. Store failures are
+// logged and returned unchanged; an unparsable strWorkflowID is wrapped.
+func (s *RPC) checkAgentPermissionByWorkflow(_ context.Context, agent *model.Agent, strWorkflowID string, pipeline *model.Pipeline, repo *model.Repo) error {
+	if repo == nil && pipeline == nil {
+		workflowID, err := strconv.ParseInt(strWorkflowID, 10, 64)
+		if err != nil {
+			return fmt.Errorf("parsing workflow id %q: %w", strWorkflowID, err)
+		}
+
+		workflow, err := s.store.WorkflowLoad(workflowID)
+		if err != nil {
+			log.Error().Err(err).Msgf("cannot find workflow with id %d", workflowID)
+			return err
+		}
+
+		pipeline, err = s.store.GetPipeline(workflow.PipelineID)
+		if err != nil {
+			log.Error().Err(err).Msgf("cannot find pipeline with id %d", workflow.PipelineID)
+			return err
+		}
+	}
+
+	if repo == nil {
+		var err error
+		repo, err = s.store.GetRepo(pipeline.RepoID)
+		if err != nil {
+			log.Error().Err(err).Msgf("cannot find repo with id %d", pipeline.RepoID)
+			return err
+		}
+	}
+
+	if agent.CanAccessRepo(repo) {
+		return nil
+	}
+
+	msg := fmt.Sprintf("agent '%d' is not allowed to interact with repo[%d] '%s'", agent.ID, repo.ID, repo.FullName)
+	log.Error().Int64("repoId", repo.ID).Msg(msg)
+	return errors.New(msg)
+}
+
+// checkPipelineState reports whether an agent may still change the state of
+// a workflow or step, judged by the status of the parent pipeline.
+//
+// Created, pending and running pipelines are accepted. A blocked pipeline
+// yields ErrAgentIllegalPipelineWorkflowRun, any other status yields
+// ErrAgentIllegalPipelineWorkflowReRunStateChange. Rejections are logged.
+func checkPipelineState(currPipeline *model.Pipeline) (err error) {
+	switch currPipeline.Status {
+	case model.StatusCreated,
+		model.StatusPending,
+		model.StatusRunning:
+		// Pipeline is still in flight: nothing to object to.
+		return nil
+
+	case model.StatusBlocked:
+		err = ErrAgentIllegalPipelineWorkflowRun
+
+	default:
+		// Pipeline already ran and was marked finished.
+		err = ErrAgentIllegalPipelineWorkflowReRunStateChange
+	}
+
+	log.Error().Err(err).Msg("caught agent performing illegal instruction")
+	return err
+}
+
+// checkWorkflowStepStates reports whether a workflow/step state or its logs
+// may be altered, judged by the state the workflow and step currently are in.
+//
+// Either argument may be nil, in which case it is not checked. Created,
+// pending and running are accepted; blocked and every other state are
+// rejected with the matching ErrAgentIllegal* error. If both the workflow
+// and the step are rejected, the errors are combined with errors.Join.
+func checkWorkflowStepStates(currWorkflow *model.Workflow, currStep *model.Step) (err error) {
+	if currWorkflow != nil {
+		switch currWorkflow.State {
+		case model.StatusCreated,
+			model.StatusPending,
+			model.StatusRunning:
+			// Still in flight: allowed.
+
+		case model.StatusBlocked:
+			err = ErrAgentIllegalWorkflowRun
+
+		default:
+			err = ErrAgentIllegalWorkflowReRunStateChange
+		}
+	}
+
+	if currStep != nil {
+		switch currStep.State {
+		case model.StatusCreated,
+			model.StatusPending,
+			model.StatusRunning:
+			// Still in flight: allowed.
+
+		case model.StatusBlocked:
+			err = errors.Join(err, ErrAgentIllegalStepRun)
+
+		default:
+			err = errors.Join(err, ErrAgentIllegalStepReRunStateChange)
+		}
+	}
+
+	if err != nil {
+		log.Error().Err(err).Msg("caught agent performing illegal instruction")
+	}
+	return err
+}
+
+// allowAppendingLogs decides whether an agent may still append logs to
+// currStep.
+//
+// Logs are accepted while the step or its pipeline is running, and for
+// logStreamDelayAllowed after currPipeline.Finished. Otherwise
+// ErrAgentIllegalLogStreaming is logged and returned.
+func allowAppendingLogs(currPipeline *model.Pipeline, currStep *model.Step) error {
+	// As long as the step or the pipeline is running, just let the agent send logs.
+	if currStep.State == model.StatusRunning || currPipeline.Status == model.StatusRunning {
+		return nil
+	}
+
+	// Otherwise grant a grace period in which log caches can drain and be
+	// sent, e.g. after a network outage or a server restart.
+	drainDeadline := time.Unix(currPipeline.Finished, 0).Add(logStreamDelayAllowed)
+	if time.Now().Before(drainDeadline) {
+		return nil
+	}
+
+	err := ErrAgentIllegalLogStreaming
+	log.Error().Err(err).Msg("caught agent performing illegal instruction")
+	return err
+}
diff --git a/server/rpc/sanitize_test.go b/server/rpc/sanitize_test.go
new file mode 100644
index 000000000..6791bbf25
--- /dev/null
+++ b/server/rpc/sanitize_test.go
@@ -0,0 +1,433 @@
+// Copyright 2026 Woodpecker Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpc
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"go.woodpecker-ci.org/woodpecker/v3/server/model"
+)
+
+func TestCheckPipelineState(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name      string
+		status    model.StatusValue
+		wantErr   error
+		expectNil bool // set for the allowed states; wantErr is left nil there
+	}{
+		{
+			name:      "created is allowed",
+			status:    model.StatusCreated,
+			expectNil: true,
+		},
+		{
+			name:      "pending is allowed",
+			status:    model.StatusPending,
+			expectNil: true,
+		},
+		{
+			name:      "running is allowed",
+			status:    model.StatusRunning,
+			expectNil: true,
+		},
+		{
+			name:    "blocked is rejected",
+			status:  model.StatusBlocked,
+			wantErr: ErrAgentIllegalPipelineWorkflowRun,
+		},
+		{
+			name:    "success is rejected as re-run",
+			status:  model.StatusSuccess,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+		{
+			name:    "failure is rejected as re-run",
+			status:  model.StatusFailure,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+		{
+			name:    "killed is rejected as re-run",
+			status:  model.StatusKilled,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+		{
+			name:    "error is rejected as re-run",
+			status:  model.StatusError,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+		{
+			name:    "skipped is rejected as re-run",
+			status:  model.StatusSkipped,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+		{
+			name:    "declined is rejected as re-run",
+			status:  model.StatusDeclined,
+			wantErr: ErrAgentIllegalPipelineWorkflowReRunStateChange,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			pipeline := &model.Pipeline{Status: tt.status}
+			err := checkPipelineState(pipeline)
+
+			if tt.expectNil {
+				assert.NoError(t, err)
+			} else {
+				assert.ErrorIs(t, err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCheckWorkflowStepStates(t *testing.T) {
+	t.Parallel()
+
+	t.Run("workflow only", func(t *testing.T) {
+		t.Parallel()
+
+		tests := []struct {
+			name    string
+			state   model.StatusValue
+			wantErr error
+		}{
+			{"created allows", model.StatusCreated, nil},
+			{"pending allows", model.StatusPending, nil},
+			{"running allows", model.StatusRunning, nil},
+			{"blocked rejects", model.StatusBlocked, ErrAgentIllegalWorkflowRun},
+			{"success rejects", model.StatusSuccess, ErrAgentIllegalWorkflowReRunStateChange},
+			{"failure rejects", model.StatusFailure, ErrAgentIllegalWorkflowReRunStateChange},
+			{"killed rejects", model.StatusKilled, ErrAgentIllegalWorkflowReRunStateChange},
+			{"error rejects", model.StatusError, ErrAgentIllegalWorkflowReRunStateChange},
+			{"skipped rejects", model.StatusSkipped, ErrAgentIllegalWorkflowReRunStateChange},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				t.Parallel()
+
+				workflow := &model.Workflow{State: tt.state}
+				err := checkWorkflowStepStates(workflow, nil)
+
+				if tt.wantErr == nil {
+					assert.NoError(t, err)
+				} else {
+					assert.ErrorIs(t, err, tt.wantErr)
+				}
+			})
+		}
+	})
+
+	t.Run("step only (nil workflow)", func(t *testing.T) {
+		t.Parallel()
+
+		tests := []struct {
+			name    string
+			state   model.StatusValue
+			wantErr error
+		}{
+			{"created allows", model.StatusCreated, nil},
+			{"pending allows", model.StatusPending, nil},
+			{"running allows", model.StatusRunning, nil},
+			{"blocked rejects", model.StatusBlocked, ErrAgentIllegalStepRun},
+			{"success rejects", model.StatusSuccess, ErrAgentIllegalStepReRunStateChange},
+			{"failure rejects", model.StatusFailure, ErrAgentIllegalStepReRunStateChange},
+			{"killed rejects", model.StatusKilled, ErrAgentIllegalStepReRunStateChange},
+			{"error rejects", model.StatusError, ErrAgentIllegalStepReRunStateChange},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				t.Parallel()
+
+				step := &model.Step{State: tt.state}
+				err := checkWorkflowStepStates(nil, step)
+
+				if tt.wantErr == nil {
+					assert.NoError(t, err)
+				} else {
+					assert.ErrorIs(t, err, tt.wantErr)
+				}
+			})
+		}
+	})
+
+	t.Run("nil workflow and nil step", func(t *testing.T) {
+		t.Parallel()
+
+		assert.NoError(t, checkWorkflowStepStates(nil, nil))
+	})
+
+	t.Run("workflow running, step running", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusRunning}
+		step := &model.Step{State: model.StatusRunning}
+		assert.NoError(t, checkWorkflowStepStates(workflow, step))
+	})
+
+	t.Run("workflow running, step finished", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusRunning}
+		step := &model.Step{State: model.StatusSuccess}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepReRunStateChange)
+		// the running workflow must not contribute an error of its own
+		assert.False(t, errors.Is(err, ErrAgentIllegalWorkflowReRunStateChange))
+	})
+
+	t.Run("workflow running, step blocked", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusRunning}
+		step := &model.Step{State: model.StatusBlocked}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepRun)
+	})
+
+	t.Run("both finished - joined errors", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusSuccess}
+		step := &model.Step{State: model.StatusSuccess}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepReRunStateChange)
+	})
+
+	t.Run("both blocked - joined errors", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusBlocked}
+		step := &model.Step{State: model.StatusBlocked}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalWorkflowRun)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepRun)
+	})
+
+	t.Run("workflow finished, step blocked - joined errors", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusKilled}
+		step := &model.Step{State: model.StatusBlocked}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepRun)
+	})
+
+	t.Run("workflow finished (failure), step finished (failure) - joined errors", func(t *testing.T) {
+		t.Parallel()
+
+		workflow := &model.Workflow{State: model.StatusFailure}
+		step := &model.Step{State: model.StatusFailure}
+		err := checkWorkflowStepStates(workflow, step)
+		assert.ErrorIs(t, err, ErrAgentIllegalWorkflowReRunStateChange)
+		assert.ErrorIs(t, err, ErrAgentIllegalStepReRunStateChange)
+	})
+}
+
+// TestAllowAppendingLogs covers allowAppendingLogs(pipeline, step).
+//
+// The cases are grouped by the rule they exercise:
+//   - allow if step.State is running (step is actively running)
+//   - allow if pipeline.Status is running (pipeline still running, step may
+//     have just finished but pipeline hasn't caught up yet)
+//   - allow if pipeline.Finished is within the last logStreamDelayAllowed
+//     (drain window after a server restart / network blip)
+//   - reject otherwise, which includes a Finished timestamp of zero
+//     combined with a pipeline and step that are not running
+func TestAllowAppendingLogs(t *testing.T) {
+	t.Parallel()
+
+	// recentFinish is a pipeline.Finished timestamp just 30 seconds ago —
+	// well within the 5-minute drain window.
+	recentFinish := time.Now().Add(-30 * time.Second).Unix()
+
+	// staleFinish is a pipeline.Finished timestamp 10 minutes ago —
+	// outside the drain window.
+	staleFinish := time.Now().Add(-10 * time.Minute).Unix()
+
+	tests := []struct {
+		name           string
+		pipelineStatus model.StatusValue
+		pipelineFinish int64
+		stepState      model.StatusValue
+		wantErr        error
+	}{
+		// --- step is running: always allowed regardless of pipeline state ----
+		{
+			name:           "step running, pipeline running → allow",
+			pipelineStatus: model.StatusRunning,
+			stepState:      model.StatusRunning,
+		},
+		{
+			name:           "step running, pipeline success → allow (step takes priority)",
+			pipelineStatus: model.StatusSuccess,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusRunning,
+		},
+		{
+			name:           "step running, pipeline failure → allow",
+			pipelineStatus: model.StatusFailure,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusRunning,
+		},
+		{
+			name:           "step running, pipeline killed → allow",
+			pipelineStatus: model.StatusKilled,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusRunning,
+		},
+
+		// --- pipeline still running: allow even if step finished ------------
+		{
+			name:           "step success, pipeline still running → allow",
+			pipelineStatus: model.StatusRunning,
+			stepState:      model.StatusSuccess,
+		},
+		{
+			name:           "step failure, pipeline still running → allow",
+			pipelineStatus: model.StatusRunning,
+			stepState:      model.StatusFailure,
+		},
+		{
+			name:           "step pending, pipeline still running → allow",
+			pipelineStatus: model.StatusRunning,
+			stepState:      model.StatusPending,
+		},
+		{
+			name:           "step killed, pipeline still running → allow",
+			pipelineStatus: model.StatusRunning,
+			stepState:      model.StatusKilled,
+		},
+
+		// --- pipeline finished recently: drain window allows logs -----------
+		{
+			name:           "step success, pipeline finished recently → allow (drain window)",
+			pipelineStatus: model.StatusSuccess,
+			pipelineFinish: recentFinish,
+			stepState:      model.StatusSuccess,
+		},
+		{
+			name:           "step failure, pipeline failed recently → allow (drain window)",
+			pipelineStatus: model.StatusFailure,
+			pipelineFinish: recentFinish,
+			stepState:      model.StatusFailure,
+		},
+		{
+			name:           "step pending, pipeline killed recently → allow (drain window)",
+			pipelineStatus: model.StatusKilled,
+			pipelineFinish: recentFinish,
+			stepState:      model.StatusPending,
+		},
+
+		// --- pipeline finished and drain window expired: reject -------------
+		{
+			name:           "step success, pipeline success, stale finish → reject",
+			pipelineStatus: model.StatusSuccess,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusSuccess,
+			wantErr:        ErrAgentIllegalLogStreaming,
+		},
+		{
+			name:           "step failure, pipeline failure, stale finish → reject",
+			pipelineStatus: model.StatusFailure,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusFailure,
+			wantErr:        ErrAgentIllegalLogStreaming,
+		},
+		{
+			name:           "step pending, pipeline killed, stale finish → reject",
+			pipelineStatus: model.StatusKilled,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusPending,
+			wantErr:        ErrAgentIllegalLogStreaming,
+		},
+		{
+			name:           "step created, pipeline error, stale finish → reject",
+			pipelineStatus: model.StatusError,
+			pipelineFinish: staleFinish,
+			stepState:      model.StatusCreated,
+			wantErr:        ErrAgentIllegalLogStreaming,
+		},
+
+		// --- zero Finished timestamp (never recorded): outside drain window -
+		{
+			name:           "step success, pipeline success, Finished=0 → reject",
+			pipelineStatus: model.StatusSuccess,
+			pipelineFinish: 0,
+			stepState:      model.StatusSuccess,
+			wantErr:        ErrAgentIllegalLogStreaming,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			pipeline := &model.Pipeline{
+				Status:   tt.pipelineStatus,
+				Finished: tt.pipelineFinish,
+			}
+			step := &model.Step{State: tt.stepState}
+
+			err := allowAppendingLogs(pipeline, step)
+
+			if tt.wantErr == nil {
+				assert.NoError(t, err)
+			} else {
+				assert.ErrorIs(t, err, tt.wantErr)
+			}
+		})
+	}
+}
+
+// TestAllowAppendingLogsDrainBoundary probes one second either side of the
+// 5-minute drain window boundary to guard against off-by-one errors.
+func TestAllowAppendingLogsDrainBoundary(t *testing.T) {
+	t.Parallel()
+
+	step := &model.Step{State: model.StatusSuccess}
+
+	t.Run("finished exactly at drain window boundary is allowed", func(t *testing.T) {
+		t.Parallel()
+
+		// Finished just barely inside the window (1 second of headroom).
+		finishedAt := time.Now().Add(-(logStreamDelayAllowed - time.Second)).Unix()
+		pipeline := &model.Pipeline{Status: model.StatusSuccess, Finished: finishedAt}
+
+		assert.NoError(t, allowAppendingLogs(pipeline, step))
+	})
+
+	t.Run("finished just outside drain window is rejected", func(t *testing.T) {
+		t.Parallel()
+
+		// Finished 1 second past the allowed window.
+		finishedAt := time.Now().Add(-(logStreamDelayAllowed + time.Second)).Unix()
+		pipeline := &model.Pipeline{Status: model.StatusSuccess, Finished: finishedAt}
+
+		assert.ErrorIs(t, allowAppendingLogs(pipeline, step), ErrAgentIllegalLogStreaming)
+	})
+}
diff --git a/server/rpc/server.go b/server/rpc/server.go
index d9cee658e..7f39e086c 100644
--- a/server/rpc/server.go
+++ b/server/rpc/server.go
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package grpc
+package rpc
 
 import (
 	"context"