diff --git a/internal/pkg/callbacks/rolling_upgrade.go b/internal/pkg/callbacks/rolling_upgrade.go index 4ef9159..77d7982 100644 --- a/internal/pkg/callbacks/rolling_upgrade.go +++ b/internal/pkg/callbacks/rolling_upgrade.go @@ -2,6 +2,7 @@ package callbacks import ( "context" + "errors" "fmt" "time" @@ -19,6 +20,9 @@ import ( openshiftv1 "github.com/openshift/api/apps/v1" ) +// ItemFunc is a generic function to return a specific resource in given namespace +type ItemFunc func(kube.Clients, string, string) (runtime.Object, error) + // ItemsFunc is a generic function to return a specific resource array in given namespace type ItemsFunc func(kube.Clients, string) []runtime.Object @@ -34,6 +38,12 @@ type VolumesFunc func(runtime.Object) []v1.Volume // UpdateFunc performs the resource update type UpdateFunc func(kube.Clients, string, runtime.Object) error +// PatchFunc performs the resource patch +type PatchFunc func(kube.Clients, string, runtime.Object, patchtypes.PatchType, []byte) error + +// PatchTemplateFunc is a generic func to return strategic merge JSON patch template +type PatchTemplatesFunc func() PatchTemplates + // AnnotationsFunc is a generic func to return annotations type AnnotationsFunc func(runtime.Object) map[string]string @@ -42,14 +52,42 @@ type PodAnnotationsFunc func(runtime.Object) map[string]string // RollingUpgradeFuncs contains generic functions to perform rolling upgrade type RollingUpgradeFuncs struct { - ItemsFunc ItemsFunc - AnnotationsFunc AnnotationsFunc - PodAnnotationsFunc PodAnnotationsFunc - ContainersFunc ContainersFunc - InitContainersFunc InitContainersFunc - UpdateFunc UpdateFunc - VolumesFunc VolumesFunc - ResourceType string + ItemFunc ItemFunc + ItemsFunc ItemsFunc + AnnotationsFunc AnnotationsFunc + PodAnnotationsFunc PodAnnotationsFunc + ContainersFunc ContainersFunc + ContainerPatchPathFunc ContainersFunc + InitContainersFunc InitContainersFunc + UpdateFunc UpdateFunc + PatchFunc PatchFunc + PatchTemplatesFunc PatchTemplatesFunc + VolumesFunc VolumesFunc + ResourceType string + SupportsPatch bool +} + +// PatchTemplates contains merge JSON patch templates +type PatchTemplates struct { + AnnotationTemplate string + EnvVarTemplate string + DeleteEnvVarTemplate string +} + +// GetDeploymentItem returns the deployment in given namespace +func GetDeploymentItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + deployment, err := clients.KubernetesClient.AppsV1().Deployments(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get deployment %v", err) + return nil, err + } + + if deployment.Spec.Template.ObjectMeta.Annotations == nil { + annotations := make(map[string]string) + deployment.Spec.Template.ObjectMeta.Annotations = annotations + } + + return deployment, nil } // GetDeploymentItems returns the deployments in given namespace @@ -72,6 +110,17 @@ func GetDeploymentItems(clients kube.Clients, namespace string) []runtime.Object return items } +// GetCronJobItem returns the job in given namespace +func GetCronJobItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + cronjob, err := clients.KubernetesClient.BatchV1().CronJobs(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get cronjob %v", err) + return nil, err + } + + return cronjob, nil +} + // GetCronJobItems returns the jobs in given namespace func GetCronJobItems(clients kube.Clients, namespace string) []runtime.Object { cronjobs, err := clients.KubernetesClient.BatchV1().CronJobs(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -92,6 +141,17 @@ func GetCronJobItems(clients kube.Clients, namespace string) []runtime.Object { return items } +// GetJobItem returns the job in given namespace +func GetJobItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + job, err := clients.KubernetesClient.BatchV1().Jobs(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get job %v", err) + return nil, err + } + + return job, nil +} + // GetJobItems returns the jobs in given namespace func GetJobItems(clients kube.Clients, namespace string) []runtime.Object { jobs, err := clients.KubernetesClient.BatchV1().Jobs(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -112,6 +172,17 @@ func GetJobItems(clients kube.Clients, namespace string) []runtime.Object { return items } +// GetDaemonSetItem returns the daemonSet in given namespace +func GetDaemonSetItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + daemonSet, err := clients.KubernetesClient.AppsV1().DaemonSets(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get daemonSet %v", err) + return nil, err + } + + return daemonSet, nil +} + // GetDaemonSetItems returns the daemonSets in given namespace func GetDaemonSetItems(clients kube.Clients, namespace string) []runtime.Object { daemonSets, err := clients.KubernetesClient.AppsV1().DaemonSets(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -131,6 +202,17 @@ func GetDaemonSetItems(clients kube.Clients, namespace string) []runtime.Object return items } +// GetStatefulSetItem returns the statefulSet in given namespace +func GetStatefulSetItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + statefulSet, err := clients.KubernetesClient.AppsV1().StatefulSets(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get statefulSet %v", err) + return nil, err + } + + return statefulSet, nil +} + // GetStatefulSetItems returns the statefulSets in given namespace func GetStatefulSetItems(clients kube.Clients, namespace string) []runtime.Object { statefulSets, err := clients.KubernetesClient.AppsV1().StatefulSets(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -150,6 +232,17 @@ func GetStatefulSetItems(clients kube.Clients, namespace string) []runtime.Objec return items } +// GetDeploymentConfigItem returns the deploymentConfig in given namespace +func GetDeploymentConfigItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + deploymentConfig, err := clients.OpenshiftAppsClient.AppsV1().DeploymentConfigs(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get deploymentConfig %v", err) + return nil, err + } + + return deploymentConfig, nil +} + // GetDeploymentConfigItems returns the deploymentConfigs in given namespace func GetDeploymentConfigItems(clients kube.Clients, namespace string) []runtime.Object { deploymentConfigs, err := clients.OpenshiftAppsClient.AppsV1().DeploymentConfigs(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -169,6 +262,17 @@ func GetDeploymentConfigItems(clients kube.Clients, namespace string) []runtime. return items } +// GetRolloutItem returns the rollout in given namespace +func GetRolloutItem(clients kube.Clients, name string, namespace string) (runtime.Object, error) { + rollout, err := clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts(namespace).Get(context.TODO(), name, meta_v1.GetOptions{}) + if err != nil { + logrus.Errorf("Failed to get Rollout %v", err) + return nil, err + } + + return rollout, nil +} + // GetRolloutItems returns the rollouts in given namespace func GetRolloutItems(clients kube.Clients, namespace string) []runtime.Object { rollouts, err := clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts(namespace).List(context.TODO(), meta_v1.ListOptions{}) @@ -190,71 +294,113 @@ func GetRolloutItems(clients kube.Clients, namespace string) []runtime.Object { // GetDeploymentAnnotations returns the annotations of given deployment func GetDeploymentAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.Deployment).ObjectMeta.Annotations == nil { + item.(*appsv1.Deployment).ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.Deployment).ObjectMeta.Annotations } // GetCronJobAnnotations returns the annotations of given cronjob func GetCronJobAnnotations(item runtime.Object) map[string]string { + if item.(*batchv1.CronJob).ObjectMeta.Annotations == nil { + item.(*batchv1.CronJob).ObjectMeta.Annotations = make(map[string]string) + } return item.(*batchv1.CronJob).ObjectMeta.Annotations } // GetJobAnnotations returns the annotations of given job func GetJobAnnotations(item runtime.Object) map[string]string { + if item.(*batchv1.Job).ObjectMeta.Annotations == nil { + item.(*batchv1.Job).ObjectMeta.Annotations = make(map[string]string) + } return item.(*batchv1.Job).ObjectMeta.Annotations } // GetDaemonSetAnnotations returns the annotations of given daemonSet func GetDaemonSetAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.DaemonSet).ObjectMeta.Annotations == nil { + item.(*appsv1.DaemonSet).ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.DaemonSet).ObjectMeta.Annotations } // GetStatefulSetAnnotations returns the annotations of given statefulSet func GetStatefulSetAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.StatefulSet).ObjectMeta.Annotations == nil { + item.(*appsv1.StatefulSet).ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.StatefulSet).ObjectMeta.Annotations } // GetDeploymentConfigAnnotations returns the annotations of given deploymentConfig func GetDeploymentConfigAnnotations(item runtime.Object) map[string]string { + if item.(*openshiftv1.DeploymentConfig).ObjectMeta.Annotations == nil { + item.(*openshiftv1.DeploymentConfig).ObjectMeta.Annotations = make(map[string]string) + } return item.(*openshiftv1.DeploymentConfig).ObjectMeta.Annotations } // GetRolloutAnnotations returns the annotations of given rollout func GetRolloutAnnotations(item runtime.Object) map[string]string { + if item.(*argorolloutv1alpha1.Rollout).ObjectMeta.Annotations == nil { + item.(*argorolloutv1alpha1.Rollout).ObjectMeta.Annotations = make(map[string]string) + } return item.(*argorolloutv1alpha1.Rollout).ObjectMeta.Annotations } // GetDeploymentPodAnnotations returns the pod's annotations of given deployment func GetDeploymentPodAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.Deployment).Spec.Template.ObjectMeta.Annotations == nil { + item.(*appsv1.Deployment).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.Deployment).Spec.Template.ObjectMeta.Annotations } // GetCronJobPodAnnotations returns the pod's annotations of given cronjob func GetCronJobPodAnnotations(item runtime.Object) map[string]string { + if item.(*batchv1.CronJob).Spec.JobTemplate.Spec.Template.ObjectMeta.Annotations == nil { + item.(*batchv1.CronJob).Spec.JobTemplate.Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*batchv1.CronJob).Spec.JobTemplate.Spec.Template.ObjectMeta.Annotations } // GetJobPodAnnotations returns the pod's annotations of given job func GetJobPodAnnotations(item runtime.Object) map[string]string { + if item.(*batchv1.Job).Spec.Template.ObjectMeta.Annotations == nil { + item.(*batchv1.Job).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*batchv1.Job).Spec.Template.ObjectMeta.Annotations } // GetDaemonSetPodAnnotations returns the pod's annotations of given daemonSet func GetDaemonSetPodAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.DaemonSet).Spec.Template.ObjectMeta.Annotations == nil { + item.(*appsv1.DaemonSet).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.DaemonSet).Spec.Template.ObjectMeta.Annotations } // GetStatefulSetPodAnnotations returns the pod's annotations of given statefulSet func GetStatefulSetPodAnnotations(item runtime.Object) map[string]string { + if item.(*appsv1.StatefulSet).Spec.Template.ObjectMeta.Annotations == nil { + item.(*appsv1.StatefulSet).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*appsv1.StatefulSet).Spec.Template.ObjectMeta.Annotations } // GetDeploymentConfigPodAnnotations returns the pod's annotations of given deploymentConfig func GetDeploymentConfigPodAnnotations(item runtime.Object) map[string]string { + if item.(*openshiftv1.DeploymentConfig).Spec.Template.ObjectMeta.Annotations == nil { + item.(*openshiftv1.DeploymentConfig).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*openshiftv1.DeploymentConfig).Spec.Template.ObjectMeta.Annotations } // GetRolloutPodAnnotations returns the pod's annotations of given rollout func GetRolloutPodAnnotations(item runtime.Object) map[string]string { + if item.(*argorolloutv1alpha1.Rollout).Spec.Template.ObjectMeta.Annotations == nil { + item.(*argorolloutv1alpha1.Rollout).Spec.Template.ObjectMeta.Annotations = make(map[string]string) + } return item.(*argorolloutv1alpha1.Rollout).Spec.Template.ObjectMeta.Annotations } @@ -328,6 +474,15 @@ func GetRolloutInitContainers(item runtime.Object) []v1.Container { return item.(*argorolloutv1alpha1.Rollout).Spec.Template.Spec.InitContainers } +// GetPatchTemplates returns patch templates +func GetPatchTemplates() PatchTemplates { + return PatchTemplates{ + AnnotationTemplate: `{"spec":{"template":{"metadata":{"annotations":{"%s":"%s"}}}}}`, // strategic merge patch + EnvVarTemplate: `{"spec":{"template":{"spec":{"containers":[{"name":"%s","env":[{"name":"%s","value":"%s"}]}]}}}}`, // strategic merge patch + DeleteEnvVarTemplate: `[{"op":"remove","path":"/spec/template/spec/containers/%d/env/%d"}]`, // JSON patch + } +} + // UpdateDeployment performs rolling upgrade on deployment func UpdateDeployment(clients kube.Clients, namespace string, resource runtime.Object) error { deployment := resource.(*appsv1.Deployment) @@ -335,6 +490,13 @@ func UpdateDeployment(clients kube.Clients, namespace string, resource runtime.O return err } +// PatchDeployment performs rolling upgrade on deployment +func PatchDeployment(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + deployment := resource.(*appsv1.Deployment) + _, err := clients.KubernetesClient.AppsV1().Deployments(namespace).Patch(context.TODO(), deployment.Name, patchType, bytes, meta_v1.PatchOptions{FieldManager: "Reloader"}) + return err +} + // CreateJobFromCronjob performs rolling upgrade on cronjob func CreateJobFromCronjob(clients kube.Clients, namespace string, resource runtime.Object) error { cronJob := resource.(*batchv1.CronJob) @@ -347,6 +509,10 @@ func CreateJobFromCronjob(clients kube.Clients, namespace string, resource runti return err } +func PatchCronJob(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + return errors.New("not supported patching: CronJob") +} + // ReCreateJobFromjob performs rolling upgrade on job func ReCreateJobFromjob(clients kube.Clients, namespace string, resource runtime.Object) error { oldJob := resource.(*batchv1.Job) @@ -379,6 +545,10 @@ func ReCreateJobFromjob(clients kube.Clients, namespace string, resource runtime return err } +func PatchJob(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + return errors.New("not supported patching: Job") +} + // UpdateDaemonSet performs rolling upgrade on daemonSet func UpdateDaemonSet(clients kube.Clients, namespace string, resource runtime.Object) error { daemonSet := resource.(*appsv1.DaemonSet) @@ -386,6 +556,12 @@ func UpdateDaemonSet(clients kube.Clients, namespace string, resource runtime.Ob return err } +func PatchDaemonSet(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + daemonSet := resource.(*appsv1.DaemonSet) + _, err := clients.KubernetesClient.AppsV1().DaemonSets(namespace).Patch(context.TODO(), daemonSet.Name, patchType, bytes, meta_v1.PatchOptions{FieldManager: "Reloader"}) + return err +} + // UpdateStatefulSet performs rolling upgrade on statefulSet func UpdateStatefulSet(clients kube.Clients, namespace string, resource runtime.Object) error { statefulSet := resource.(*appsv1.StatefulSet) @@ -393,6 +569,12 @@ func UpdateStatefulSet(clients kube.Clients, namespace string, resource runtime. return err } +func PatchStatefulSet(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + statefulSet := resource.(*appsv1.StatefulSet) + _, err := clients.KubernetesClient.AppsV1().StatefulSets(namespace).Patch(context.TODO(), statefulSet.Name, patchType, bytes, meta_v1.PatchOptions{FieldManager: "Reloader"}) + return err +} + // UpdateDeploymentConfig performs rolling upgrade on deploymentConfig func UpdateDeploymentConfig(clients kube.Clients, namespace string, resource runtime.Object) error { deploymentConfig := resource.(*openshiftv1.DeploymentConfig) @@ -400,11 +582,17 @@ func UpdateDeploymentConfig(clients kube.Clients, namespace string, resource run return err } +func PatchDeploymentConfig(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + deploymentConfig := resource.(*openshiftv1.DeploymentConfig) + _, err := clients.OpenshiftAppsClient.AppsV1().DeploymentConfigs(namespace).Patch(context.TODO(), deploymentConfig.Name, patchType, bytes, meta_v1.PatchOptions{FieldManager: "Reloader"}) + return err +} + // UpdateRollout performs rolling upgrade on rollout func UpdateRollout(clients kube.Clients, namespace string, resource runtime.Object) error { - var err error rollout := resource.(*argorolloutv1alpha1.Rollout) strategy := rollout.GetAnnotations()[options.RolloutStrategyAnnotation] + var err error switch options.ToArgoRolloutStrategy(strategy) { case options.RestartStrategy: _, err = clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts(namespace).Patch(context.TODO(), rollout.Name, patchtypes.MergePatchType, []byte(fmt.Sprintf(`{"spec": {"restartAt": "%s"}}`, time.Now().Format(time.RFC3339))), meta_v1.PatchOptions{FieldManager: "Reloader"}) @@ -414,6 +602,10 @@ func UpdateRollout(clients kube.Clients, namespace string, resource runtime.Obje return err } +func PatchRollout(clients kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + return errors.New("not supported patching: Rollout") +} + // GetDeploymentVolumes returns the Volumes of given deployment func GetDeploymentVolumes(item runtime.Object) []v1.Volume { return item.(*appsv1.Deployment).Spec.Template.Spec.Volumes diff --git a/internal/pkg/callbacks/rolling_upgrade_test.go b/internal/pkg/callbacks/rolling_upgrade_test.go index d358e21..5b6a5f1 100644 --- a/internal/pkg/callbacks/rolling_upgrade_test.go +++ b/internal/pkg/callbacks/rolling_upgrade_test.go @@ -3,6 +3,7 @@ package callbacks_test import ( "context" "fmt" + "strings" "testing" "time" @@ -10,7 +11,7 @@ import ( appsv1 "k8s.io/api/apps/v1" batchv1 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" - meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" watch "k8s.io/apimachinery/pkg/watch" @@ -18,6 +19,7 @@ import ( argorolloutv1alpha1 "github.com/argoproj/argo-rollouts/pkg/apis/rollouts/v1alpha1" fakeargoclientset "github.com/argoproj/argo-rollouts/pkg/client/clientset/versioned/fake" + patchtypes "k8s.io/apimachinery/pkg/types" "github.com/stakater/Reloader/internal/pkg/callbacks" "github.com/stakater/Reloader/internal/pkg/options" @@ -93,7 +95,7 @@ func TestUpdateRollout(t *testing.T) { t.Errorf("updating rollout: %v", err) } rollout, err = clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts( - namespace).Get(context.TODO(), rollout.Name, meta_v1.GetOptions{}) + namespace).Get(context.TODO(), rollout.Name, metav1.GetOptions{}) if err != nil { t.Errorf("getting rollout: %v", err) @@ -111,6 +113,70 @@ func TestUpdateRollout(t *testing.T) { } } +func TestPatchRollout(t *testing.T) { + namespace := "test-ns" + rollout := testutil.GetRollout(namespace, "test", map[string]string{options.RolloutStrategyAnnotation: ""}) + err := callbacks.PatchRollout(clients, namespace, rollout, patchtypes.StrategicMergePatchType, []byte(`{"spec": {}}`)) + assert.EqualError(t, err, "not supported patching: Rollout") +} + +func TestResourceItem(t *testing.T) { + fixtures := newTestFixtures() + + tests := []struct { + name string + createFunc func(kube.Clients, string, string) (runtime.Object, error) + getItemFunc func(kube.Clients, string, string) (runtime.Object, error) + deleteFunc func(kube.Clients, string, string) error + }{ + { + name: "Deployment", + createFunc: createTestDeploymentWithAnnotations, + getItemFunc: callbacks.GetDeploymentItem, + deleteFunc: deleteTestDeployment, + }, + { + name: "CronJob", + createFunc: createTestCronJobWithAnnotations, + getItemFunc: callbacks.GetCronJobItem, + deleteFunc: deleteTestCronJob, + }, + { + name: "Job", + createFunc: createTestJobWithAnnotations, + getItemFunc: callbacks.GetJobItem, + deleteFunc: deleteTestJob, + }, + { + name: "DaemonSet", + createFunc: createTestDaemonSetWithAnnotations, + getItemFunc: callbacks.GetDaemonSetItem, + deleteFunc: deleteTestDaemonSet, + }, + { + name: "StatefulSet", + createFunc: createTestStatefulSetWithAnnotations, + getItemFunc: callbacks.GetStatefulSetItem, + deleteFunc: deleteTestStatefulSet, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resource, err := tt.createFunc(clients, fixtures.namespace, "1") + assert.NoError(t, err) + + accessor, err := meta.Accessor(resource) + assert.NoError(t, err) + + _, err = tt.getItemFunc(clients, accessor.GetName(), fixtures.namespace) + assert.NoError(t, err) + + tt.deleteFunc(clients, fixtures.namespace, accessor.GetName()) + }) + } +} + func TestResourceItems(t *testing.T) { fixtures := newTestFixtures() @@ -118,36 +184,42 @@ func TestResourceItems(t *testing.T) { name string createFunc func(kube.Clients, string) error getItemsFunc func(kube.Clients, string) []runtime.Object + deleteFunc func(kube.Clients, string) expectedCount int }{ { name: "Deployments", createFunc: createTestDeployments, getItemsFunc: callbacks.GetDeploymentItems, + deleteFunc: deleteTestDeployments, expectedCount: 2, }, { name: "CronJobs", createFunc: createTestCronJobs, getItemsFunc: callbacks.GetCronJobItems, + deleteFunc: deleteTestCronJobs, expectedCount: 2, }, { name: "Jobs", createFunc: createTestJobs, getItemsFunc: callbacks.GetJobItems, + deleteFunc: deleteTestJobs, expectedCount: 2, }, { name: "DaemonSets", createFunc: createTestDaemonSets, getItemsFunc: callbacks.GetDaemonSetItems, + deleteFunc: deleteTestDaemonSets, expectedCount: 2, }, { name: "StatefulSets", createFunc: createTestStatefulSets, getItemsFunc: callbacks.GetStatefulSetItems, + deleteFunc: deleteTestStatefulSets, expectedCount: 2, }, } @@ -262,10 +334,11 @@ func TestUpdateResources(t *testing.T) { name string createFunc func(kube.Clients, string, string) (runtime.Object, error) updateFunc func(kube.Clients, string, runtime.Object) error + deleteFunc func(kube.Clients, string, string) error }{ - {"Deployment", createTestDeploymentWithAnnotations, callbacks.UpdateDeployment}, - {"DaemonSet", createTestDaemonSetWithAnnotations, callbacks.UpdateDaemonSet}, - {"StatefulSet", createTestStatefulSetWithAnnotations, callbacks.UpdateStatefulSet}, + {"Deployment", createTestDeploymentWithAnnotations, callbacks.UpdateDeployment, deleteTestDeployment}, + {"DaemonSet", createTestDaemonSetWithAnnotations, callbacks.UpdateDaemonSet, deleteTestDaemonSet}, + {"StatefulSet", createTestStatefulSetWithAnnotations, callbacks.UpdateStatefulSet, deleteTestStatefulSet}, } for _, tt := range tests { @@ -275,6 +348,63 @@ func TestUpdateResources(t *testing.T) { err = tt.updateFunc(clients, fixtures.namespace, resource) assert.NoError(t, err) + + accessor, err := meta.Accessor(resource) + assert.NoError(t, err) + + tt.deleteFunc(clients, fixtures.namespace, accessor.GetName()) + }) + } +} + +func TestPatchResources(t *testing.T) { + fixtures := newTestFixtures() + + tests := []struct { + name string + createFunc func(kube.Clients, string, string) (runtime.Object, error) + patchFunc func(kube.Clients, string, runtime.Object, patchtypes.PatchType, []byte) error + deleteFunc func(kube.Clients, string, string) error + assertFunc func(err error) + }{ + {"Deployment", createTestDeploymentWithAnnotations, callbacks.PatchDeployment, deleteTestDeployment, func(err error) { + assert.NoError(t, err) + patchedResource, err := callbacks.GetDeploymentItem(clients, "test-deployment", fixtures.namespace) + assert.NoError(t, err) + assert.Equal(t, "test", patchedResource.(*appsv1.Deployment).ObjectMeta.Annotations["test"]) + }}, + {"DaemonSet", createTestDaemonSetWithAnnotations, callbacks.PatchDaemonSet, deleteTestDaemonSet, func(err error) { + assert.NoError(t, err) + patchedResource, err := callbacks.GetDaemonSetItem(clients, "test-daemonset", fixtures.namespace) + assert.NoError(t, err) + assert.Equal(t, "test", patchedResource.(*appsv1.DaemonSet).ObjectMeta.Annotations["test"]) + }}, + {"StatefulSet", createTestStatefulSetWithAnnotations, callbacks.PatchStatefulSet, deleteTestStatefulSet, func(err error) { + assert.NoError(t, err) + patchedResource, err := callbacks.GetStatefulSetItem(clients, "test-statefulset", fixtures.namespace) + assert.NoError(t, err) + assert.Equal(t, "test", patchedResource.(*appsv1.StatefulSet).ObjectMeta.Annotations["test"]) + }}, + {"CronJob", createTestCronJobWithAnnotations, callbacks.PatchCronJob, deleteTestCronJob, func(err error) { + assert.EqualError(t, err, "not supported patching: CronJob") + }}, + {"Job", createTestJobWithAnnotations, callbacks.PatchJob, deleteTestJob, func(err error) { + assert.EqualError(t, err, "not supported patching: Job") + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resource, err := tt.createFunc(clients, fixtures.namespace, "1") + assert.NoError(t, err) + + err = tt.patchFunc(clients, fixtures.namespace, resource, patchtypes.StrategicMergePatchType, []byte(`{"metadata":{"annotations":{"test":"test"}}}`)) + tt.assertFunc(err) + + accessor, err := meta.Accessor(resource) + assert.NoError(t, err) + + tt.deleteFunc(clients, fixtures.namespace, accessor.GetName()) }) } } @@ -287,6 +417,8 @@ func TestCreateJobFromCronjob(t *testing.T) { err = callbacks.CreateJobFromCronjob(clients, fixtures.namespace, cronJob.(*batchv1.CronJob)) assert.NoError(t, err) + + deleteTestCronJob(clients, fixtures.namespace, "test-cronjob") } func TestReCreateJobFromJob(t *testing.T) { @@ -297,6 +429,8 @@ func TestReCreateJobFromJob(t *testing.T) { err = callbacks.ReCreateJobFromjob(clients, fixtures.namespace, job.(*batchv1.Job)) assert.NoError(t, err) + + deleteTestJob(clients, fixtures.namespace, "test-cronjob") } func TestGetVolumes(t *testing.T) { @@ -321,6 +455,24 @@ func TestGetVolumes(t *testing.T) { } } +func TesGetPatchTemplateAnnotation(t *testing.T) { + templates := callbacks.GetPatchTemplates() + assert.NotEmpty(t, templates.AnnotationTemplate) + assert.Equal(t, 2, strings.Count(templates.AnnotationTemplate, "%s")) +} + +func TestGetPatchTemplateEnvVar(t *testing.T) { + templates := callbacks.GetPatchTemplates() + assert.NotEmpty(t, templates.EnvVarTemplate) + assert.Equal(t, 3, strings.Count(templates.EnvVarTemplate, "%s")) +} + +func TestGetPatchDeleteTemplateEnvVar(t *testing.T) { + templates := callbacks.GetPatchTemplates() + assert.NotEmpty(t, templates.DeleteEnvVarTemplate) + assert.Equal(t, 2, strings.Count(templates.DeleteEnvVarTemplate, "%d")) +} + // Helper functions func isRestartStrategy(rollout *argorolloutv1alpha1.Rollout) bool { @@ -330,7 +482,7 @@ func isRestartStrategy(rollout *argorolloutv1alpha1.Rollout) bool { func watchRollout(name, namespace string) chan interface{} { timeOut := int64(1) modifiedChan := make(chan interface{}) - watcher, _ := clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts(namespace).Watch(context.Background(), meta_v1.ListOptions{TimeoutSeconds: &timeOut}) + watcher, _ := clients.ArgoRolloutClient.ArgoprojV1alpha1().Rollouts(namespace).Watch(context.Background(), metav1.ListOptions{TimeoutSeconds: &timeOut}) go watchModified(watcher, name, modifiedChan) return modifiedChan } @@ -358,6 +510,12 @@ func createTestDeployments(clients kube.Clients, namespace string) error { return nil } +func deleteTestDeployments(clients kube.Clients, namespace string) { + for i := 1; i <= 2; i++ { + testutil.DeleteDeployment(clients.KubernetesClient, namespace, fmt.Sprintf("test-deployment-%d", i)) + } +} + func createTestCronJobs(clients kube.Clients, namespace string) error { for i := 1; i <= 2; i++ { _, err := testutil.CreateCronJob(clients.KubernetesClient, fmt.Sprintf("test-cron-%d", i), namespace, false) @@ -368,6 +526,12 @@ func createTestCronJobs(clients kube.Clients, namespace string) error { return nil } +func deleteTestCronJobs(clients kube.Clients, namespace string) { + for i := 1; i <= 2; i++ { + testutil.DeleteCronJob(clients.KubernetesClient, namespace, fmt.Sprintf("test-cron-%d", i)) + } +} + func createTestJobs(clients kube.Clients, namespace string) error { for i := 1; i <= 2; i++ { _, err := testutil.CreateJob(clients.KubernetesClient, fmt.Sprintf("test-job-%d", i), namespace, false) @@ -378,6 +542,12 @@ func createTestJobs(clients kube.Clients, namespace string) error { return nil } +func deleteTestJobs(clients kube.Clients, namespace string) { + for i := 1; i <= 2; i++ { + testutil.DeleteJob(clients.KubernetesClient, namespace, fmt.Sprintf("test-job-%d", i)) + } +} + func createTestDaemonSets(clients kube.Clients, namespace string) error { for i := 1; i <= 2; i++ { _, err := testutil.CreateDaemonSet(clients.KubernetesClient, fmt.Sprintf("test-daemonset-%d", i), namespace, false) @@ -388,6 +558,12 @@ func createTestDaemonSets(clients kube.Clients, namespace string) error { return nil } +func deleteTestDaemonSets(clients kube.Clients, namespace string) { + for i := 1; i <= 2; i++ { + testutil.DeleteDaemonSet(clients.KubernetesClient, namespace, fmt.Sprintf("test-daemonset-%d", i)) + } +} + func createTestStatefulSets(clients kube.Clients, namespace string) error { for i := 1; i <= 2; i++ { _, err := testutil.CreateStatefulSet(clients.KubernetesClient, fmt.Sprintf("test-statefulset-%d", i), namespace, false) @@ -398,6 +574,12 @@ func createTestStatefulSets(clients kube.Clients, namespace string) error { return nil } +func deleteTestStatefulSets(clients kube.Clients, namespace string) { + for i := 1; i <= 2; i++ { + testutil.DeleteStatefulSet(clients.KubernetesClient, namespace, fmt.Sprintf("test-statefulset-%d", i)) + } +} + func createResourceWithPodAnnotations(obj runtime.Object, annotations map[string]string) runtime.Object { switch v := obj.(type) { case *appsv1.Deployment: @@ -479,6 +661,10 @@ func createTestDeploymentWithAnnotations(clients kube.Clients, namespace, versio return clients.KubernetesClient.AppsV1().Deployments(namespace).Create(context.TODO(), deployment, metav1.CreateOptions{}) } +func deleteTestDeployment(clients kube.Clients, namespace, name string) error { + return clients.KubernetesClient.AppsV1().Deployments(namespace).Delete(context.TODO(), name, metav1.DeleteOptions{}) +} + func createTestDaemonSetWithAnnotations(clients kube.Clients, namespace, version string) (runtime.Object, error) { daemonSet := &appsv1.DaemonSet{ ObjectMeta: metav1.ObjectMeta{ @@ -490,6 +676,10 @@ func createTestDaemonSetWithAnnotations(clients kube.Clients, namespace, version return clients.KubernetesClient.AppsV1().DaemonSets(namespace).Create(context.TODO(), daemonSet, metav1.CreateOptions{}) } +func deleteTestDaemonSet(clients kube.Clients, namespace, name string) error { + return clients.KubernetesClient.AppsV1().DaemonSets(namespace).Delete(context.TODO(), name, metav1.DeleteOptions{}) +} + func createTestStatefulSetWithAnnotations(clients kube.Clients, namespace, version string) (runtime.Object, error) { statefulSet := &appsv1.StatefulSet{ ObjectMeta: metav1.ObjectMeta{ @@ -501,6 +691,10 @@ func createTestStatefulSetWithAnnotations(clients kube.Clients, namespace, versi return clients.KubernetesClient.AppsV1().StatefulSets(namespace).Create(context.TODO(), statefulSet, metav1.CreateOptions{}) } +func deleteTestStatefulSet(clients kube.Clients, namespace, name string) error { + return clients.KubernetesClient.AppsV1().StatefulSets(namespace).Delete(context.TODO(), name, metav1.DeleteOptions{}) +} + func createTestCronJobWithAnnotations(clients kube.Clients, namespace, version string) (runtime.Object, error) { cronJob := &batchv1.CronJob{ ObjectMeta: metav1.ObjectMeta{ @@ -512,6 +706,10 @@ func createTestCronJobWithAnnotations(clients kube.Clients, namespace, version s return clients.KubernetesClient.BatchV1().CronJobs(namespace).Create(context.TODO(), cronJob, metav1.CreateOptions{}) } +func deleteTestCronJob(clients kube.Clients, namespace, name string) error { + return clients.KubernetesClient.BatchV1().CronJobs(namespace).Delete(context.TODO(), name, metav1.DeleteOptions{}) +} + func createTestJobWithAnnotations(clients kube.Clients, namespace, version string) (runtime.Object, error) { job := &batchv1.Job{ ObjectMeta: metav1.ObjectMeta{ @@ -522,3 +720,7 @@ func createTestJobWithAnnotations(clients kube.Clients, namespace, version strin } return clients.KubernetesClient.BatchV1().Jobs(namespace).Create(context.TODO(), job, metav1.CreateOptions{}) } + +func deleteTestJob(clients kube.Clients, namespace, name string) error { + return clients.KubernetesClient.BatchV1().Jobs(namespace).Delete(context.TODO(), name, metav1.DeleteOptions{}) +} diff --git a/internal/pkg/handler/delete.go b/internal/pkg/handler/delete.go index 2378d0f..772cfca 100644 --- a/internal/pkg/handler/delete.go +++ b/internal/pkg/handler/delete.go @@ -1,6 +1,9 @@ package handler import ( + "fmt" + "slices" + "github.com/sirupsen/logrus" "github.com/stakater/Reloader/internal/pkg/callbacks" "github.com/stakater/Reloader/internal/pkg/constants" @@ -8,8 +11,10 @@ import ( "github.com/stakater/Reloader/internal/pkg/options" "github.com/stakater/Reloader/internal/pkg/testutil" "github.com/stakater/Reloader/internal/pkg/util" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" + patchtypes "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" ) @@ -50,7 +55,7 @@ func (r ResourceDeleteHandler) GetConfig() (util.Config, string) { return config, oldSHAData } -func invokeDeleteStrategy(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { +func invokeDeleteStrategy(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { if options.ReloadStrategy == constants.AnnotationsReloadStrategy { return removePodAnnotations(upgradeFuncs, item, config, autoReload) } @@ -58,35 +63,38 @@ func invokeDeleteStrategy(upgradeFuncs callbacks.RollingUpgradeFuncs, item runti return removeContainerEnvVars(upgradeFuncs, item, config, autoReload) } -func removePodAnnotations(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { +func removePodAnnotations(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { config.SHAValue = testutil.GetSHAfromEmptyData() return updatePodAnnotations(upgradeFuncs, item, config, autoReload) } -func removeContainerEnvVars(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { +func removeContainerEnvVars(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { envVar := getEnvVarName(config.ResourceName, config.Type) container := getContainerUsingResource(upgradeFuncs, item, config, autoReload) if container == nil { - return constants.NoContainerFound + return InvokeStrategyResult{constants.NoContainerFound, nil} } //remove if env var exists - containers := upgradeFuncs.ContainersFunc(item) - for i := range containers { - envs := containers[i].Env - index := -1 - for j := range envs { - if envs[j].Name == envVar { - index = j - break - } - } + if len(container.Env) > 0 { + index := slices.IndexFunc(container.Env, func(envVariable v1.EnvVar) bool { + return envVariable.Name == envVar + }) if index != -1 { - containers[i].Env = append(containers[i].Env[:index], containers[i].Env[index+1:]...) - return constants.Updated + var patch []byte + if upgradeFuncs.SupportsPatch { + containers := upgradeFuncs.ContainersFunc(item) + containerIndex := slices.IndexFunc(containers, func(c v1.Container) bool { + return c.Name == container.Name + }) + patch = fmt.Appendf(nil, upgradeFuncs.PatchTemplatesFunc().DeleteEnvVarTemplate, containerIndex, index) + } + + container.Env = append(container.Env[:index], container.Env[index+1:]...) + return InvokeStrategyResult{constants.Updated, &Patch{Type: patchtypes.JSONPatchType, Bytes: patch}} } } - return constants.NotUpdated + return InvokeStrategyResult{constants.NotUpdated, nil} } diff --git a/internal/pkg/handler/upgrade.go b/internal/pkg/handler/upgrade.go index 8365fb5..508f926 100644 --- a/internal/pkg/handler/upgrade.go +++ b/internal/pkg/handler/upgrade.go @@ -24,104 +24,134 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime" + patchtypes "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" + "k8s.io/client-go/util/retry" ) // GetDeploymentRollingUpgradeFuncs returns all callback funcs for a deployment func GetDeploymentRollingUpgradeFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetDeploymentItem, ItemsFunc: callbacks.GetDeploymentItems, AnnotationsFunc: callbacks.GetDeploymentAnnotations, PodAnnotationsFunc: callbacks.GetDeploymentPodAnnotations, ContainersFunc: callbacks.GetDeploymentContainers, InitContainersFunc: callbacks.GetDeploymentInitContainers, UpdateFunc: callbacks.UpdateDeployment, + PatchFunc: callbacks.PatchDeployment, + PatchTemplatesFunc: callbacks.GetPatchTemplates, VolumesFunc: callbacks.GetDeploymentVolumes, ResourceType: "Deployment", + SupportsPatch: true, } } // GetDeploymentRollingUpgradeFuncs returns all callback funcs for a cronjob func GetCronJobCreateJobFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetCronJobItem, ItemsFunc: callbacks.GetCronJobItems, AnnotationsFunc: callbacks.GetCronJobAnnotations, PodAnnotationsFunc: callbacks.GetCronJobPodAnnotations, ContainersFunc: callbacks.GetCronJobContainers, InitContainersFunc: callbacks.GetCronJobInitContainers, UpdateFunc: callbacks.CreateJobFromCronjob, + PatchFunc: callbacks.PatchCronJob, + PatchTemplatesFunc: func() callbacks.PatchTemplates { return callbacks.PatchTemplates{} }, VolumesFunc: callbacks.GetCronJobVolumes, ResourceType: "CronJob", + SupportsPatch: false, } } // GetDeploymentRollingUpgradeFuncs returns all callback funcs for a cronjob func GetJobCreateJobFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetJobItem, ItemsFunc: callbacks.GetJobItems, AnnotationsFunc: callbacks.GetJobAnnotations, PodAnnotationsFunc: callbacks.GetJobPodAnnotations, ContainersFunc: callbacks.GetJobContainers, InitContainersFunc: callbacks.GetJobInitContainers, UpdateFunc: callbacks.ReCreateJobFromjob, + PatchFunc: callbacks.PatchJob, + PatchTemplatesFunc: func() callbacks.PatchTemplates { return callbacks.PatchTemplates{} }, VolumesFunc: callbacks.GetJobVolumes, ResourceType: "Job", + SupportsPatch: false, } } // GetDaemonSetRollingUpgradeFuncs returns all callback funcs for a daemonset func GetDaemonSetRollingUpgradeFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetDaemonSetItem, ItemsFunc: callbacks.GetDaemonSetItems, AnnotationsFunc: callbacks.GetDaemonSetAnnotations, PodAnnotationsFunc: callbacks.GetDaemonSetPodAnnotations, ContainersFunc: callbacks.GetDaemonSetContainers, InitContainersFunc: callbacks.GetDaemonSetInitContainers, UpdateFunc: callbacks.UpdateDaemonSet, + PatchFunc: callbacks.PatchDaemonSet, + PatchTemplatesFunc: callbacks.GetPatchTemplates, VolumesFunc: callbacks.GetDaemonSetVolumes, ResourceType: "DaemonSet", + SupportsPatch: true, } } // GetStatefulSetRollingUpgradeFuncs returns all callback funcs for a statefulSet func GetStatefulSetRollingUpgradeFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetStatefulSetItem, ItemsFunc: callbacks.GetStatefulSetItems, AnnotationsFunc: callbacks.GetStatefulSetAnnotations, PodAnnotationsFunc: callbacks.GetStatefulSetPodAnnotations, ContainersFunc: callbacks.GetStatefulSetContainers, InitContainersFunc: callbacks.GetStatefulSetInitContainers, UpdateFunc: callbacks.UpdateStatefulSet, + PatchFunc: callbacks.PatchStatefulSet, + PatchTemplatesFunc: callbacks.GetPatchTemplates, VolumesFunc: callbacks.GetStatefulSetVolumes, ResourceType: "StatefulSet", + SupportsPatch: true, } } // GetDeploymentConfigRollingUpgradeFuncs returns all callback funcs for a deploymentConfig func GetDeploymentConfigRollingUpgradeFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetDeploymentConfigItem, ItemsFunc: callbacks.GetDeploymentConfigItems, AnnotationsFunc: callbacks.GetDeploymentConfigAnnotations, PodAnnotationsFunc: callbacks.GetDeploymentConfigPodAnnotations, ContainersFunc: callbacks.GetDeploymentConfigContainers, InitContainersFunc: callbacks.GetDeploymentConfigInitContainers, UpdateFunc: callbacks.UpdateDeploymentConfig, + PatchFunc: callbacks.PatchDeploymentConfig, + PatchTemplatesFunc: callbacks.GetPatchTemplates, VolumesFunc: callbacks.GetDeploymentConfigVolumes, ResourceType: "DeploymentConfig", + SupportsPatch: true, } } // GetArgoRolloutRollingUpgradeFuncs returns all callback funcs for a rollout func GetArgoRolloutRollingUpgradeFuncs() callbacks.RollingUpgradeFuncs { return callbacks.RollingUpgradeFuncs{ + ItemFunc: callbacks.GetRolloutItem, ItemsFunc: callbacks.GetRolloutItems, AnnotationsFunc: callbacks.GetRolloutAnnotations, PodAnnotationsFunc: callbacks.GetRolloutPodAnnotations, ContainersFunc: callbacks.GetRolloutContainers, InitContainersFunc: callbacks.GetRolloutInitContainers, UpdateFunc: callbacks.UpdateRollout, + PatchFunc: callbacks.PatchRollout, + PatchTemplatesFunc: func() callbacks.PatchTemplates { return callbacks.PatchTemplates{} }, VolumesFunc: callbacks.GetRolloutVolumes, ResourceType: "Rollout", + SupportsPatch: false, } } @@ -210,107 +240,131 @@ func rollingUpgrade(clients kube.Clients, config util.Config, upgradeFuncs callb func PerformAction(clients kube.Clients, config util.Config, upgradeFuncs callbacks.RollingUpgradeFuncs, collectors metrics.Collectors, recorder record.EventRecorder, strategy invokeStrategy) error { items := upgradeFuncs.ItemsFunc(clients, config.Namespace) - for _, i := range items { - // find correct annotation and update the resource - annotations := upgradeFuncs.AnnotationsFunc(i) - annotationValue, found := annotations[config.Annotation] - searchAnnotationValue, foundSearchAnn := annotations[options.AutoSearchAnnotation] - reloaderEnabledValue, foundAuto := annotations[options.ReloaderAutoAnnotation] - typedAutoAnnotationEnabledValue, foundTypedAuto := annotations[config.TypedAutoAnnotation] - excludeConfigmapAnnotationValue, foundExcludeConfigmap := annotations[options.ConfigmapExcludeReloaderAnnotation] - excludeSecretAnnotationValue, foundExcludeSecret := annotations[options.SecretExcludeReloaderAnnotation] + for _, item := range items { + err := retry.RetryOnConflict(retry.DefaultRetry, func() error { + return upgradeResource(clients, config, upgradeFuncs, collectors, recorder, strategy, item) + }) - if !found && !foundAuto && !foundTypedAuto && !foundSearchAnn { - annotations = upgradeFuncs.PodAnnotationsFunc(i) - annotationValue = annotations[config.Annotation] - searchAnnotationValue = annotations[options.AutoSearchAnnotation] - reloaderEnabledValue = annotations[options.ReloaderAutoAnnotation] - typedAutoAnnotationEnabledValue = annotations[config.TypedAutoAnnotation] + if err != nil { + return err } + } - isResourceExcluded := false + return nil +} - switch config.Type { - case constants.ConfigmapEnvVarPostfix: - if foundExcludeConfigmap { - isResourceExcluded = checkIfResourceIsExcluded(config.ResourceName, excludeConfigmapAnnotationValue) - } - case constants.SecretEnvVarPostfix: - if foundExcludeSecret { - isResourceExcluded = checkIfResourceIsExcluded(config.ResourceName, excludeSecretAnnotationValue) - } +func upgradeResource(clients kube.Clients, config util.Config, upgradeFuncs callbacks.RollingUpgradeFuncs, collectors metrics.Collectors, recorder record.EventRecorder, strategy invokeStrategy, resource runtime.Object) error { + accessor, err := meta.Accessor(resource) + if err != nil { + return err + } + + resourceName := accessor.GetName() + resource, err = upgradeFuncs.ItemFunc(clients, resourceName, config.Namespace) + if err != nil { + return err + } + + // find correct annotation and update the resource + annotations := upgradeFuncs.AnnotationsFunc(resource) + annotationValue, found := annotations[config.Annotation] + searchAnnotationValue, foundSearchAnn := annotations[options.AutoSearchAnnotation] + reloaderEnabledValue, foundAuto := annotations[options.ReloaderAutoAnnotation] + typedAutoAnnotationEnabledValue, foundTypedAuto := annotations[config.TypedAutoAnnotation] + excludeConfigmapAnnotationValue, foundExcludeConfigmap := annotations[options.ConfigmapExcludeReloaderAnnotation] + excludeSecretAnnotationValue, foundExcludeSecret := annotations[options.SecretExcludeReloaderAnnotation] + + if !found && !foundAuto && !foundTypedAuto && !foundSearchAnn { + annotations = upgradeFuncs.PodAnnotationsFunc(resource) + annotationValue = annotations[config.Annotation] + searchAnnotationValue = annotations[options.AutoSearchAnnotation] + reloaderEnabledValue = annotations[options.ReloaderAutoAnnotation] + typedAutoAnnotationEnabledValue = annotations[config.TypedAutoAnnotation] + } + + isResourceExcluded := false + + switch config.Type { + case constants.ConfigmapEnvVarPostfix: + if foundExcludeConfigmap { + isResourceExcluded = checkIfResourceIsExcluded(config.ResourceName, excludeConfigmapAnnotationValue) } - - if isResourceExcluded { - continue + case constants.SecretEnvVarPostfix: + if foundExcludeSecret { + isResourceExcluded = checkIfResourceIsExcluded(config.ResourceName, excludeSecretAnnotationValue) } + } - result := constants.NotUpdated - reloaderEnabled, _ := strconv.ParseBool(reloaderEnabledValue) - typedAutoAnnotationEnabled, _ := strconv.ParseBool(typedAutoAnnotationEnabledValue) - if reloaderEnabled || typedAutoAnnotationEnabled || reloaderEnabledValue == "" && typedAutoAnnotationEnabledValue == "" && options.AutoReloadAll { - result = strategy(upgradeFuncs, i, config, true) - } + if isResourceExcluded { + return nil + } - if result != constants.Updated && annotationValue != "" { - values := strings.Split(annotationValue, ",") - for _, value := range values { - value = strings.TrimSpace(value) - re := regexp.MustCompile("^" + value + "$") - if re.Match([]byte(config.ResourceName)) { - result = strategy(upgradeFuncs, i, config, false) - if result == constants.Updated { - break - } - } - } - } + strategyResult := InvokeStrategyResult{constants.NotUpdated, nil} + reloaderEnabled, _ := strconv.ParseBool(reloaderEnabledValue) + typedAutoAnnotationEnabled, _ := strconv.ParseBool(typedAutoAnnotationEnabledValue) + if reloaderEnabled || typedAutoAnnotationEnabled || reloaderEnabledValue == "" && typedAutoAnnotationEnabledValue == "" && options.AutoReloadAll { + strategyResult = strategy(upgradeFuncs, resource, config, true) + } - if result != constants.Updated && searchAnnotationValue == "true" { - matchAnnotationValue := config.ResourceAnnotations[options.SearchMatchAnnotation] - if matchAnnotationValue == "true" { - result = strategy(upgradeFuncs, i, config, true) - } - } - - if result == constants.Updated { - accessor, err := meta.Accessor(i) - if err != nil { - return err - } - resourceName := accessor.GetName() - err = upgradeFuncs.UpdateFunc(clients, config.Namespace, i) - if err != nil { - message := fmt.Sprintf("Update for '%s' of type '%s' in namespace '%s' failed with error %v", resourceName, upgradeFuncs.ResourceType, config.Namespace, err) - logrus.Errorf("Update for '%s' of type '%s' in namespace '%s' failed with error %v", resourceName, upgradeFuncs.ResourceType, config.Namespace, err) - - collectors.Reloaded.With(prometheus.Labels{"success": "false"}).Inc() - collectors.ReloadedByNamespace.With(prometheus.Labels{"success": "false", "namespace": config.Namespace}).Inc() - if recorder != nil { - recorder.Event(i, v1.EventTypeWarning, "ReloadFail", message) - } - return err - } else { - message := fmt.Sprintf("Changes detected in '%s' of type '%s' in namespace '%s'", config.ResourceName, config.Type, config.Namespace) - message += fmt.Sprintf(", Updated '%s' of type '%s' in namespace '%s'", resourceName, upgradeFuncs.ResourceType, config.Namespace) - - logrus.Infof("Changes detected in '%s' of type '%s' in namespace '%s'; updated '%s' of type '%s' in namespace '%s'", config.ResourceName, config.Type, config.Namespace, resourceName, upgradeFuncs.ResourceType, config.Namespace) - - collectors.Reloaded.With(prometheus.Labels{"success": "true"}).Inc() - collectors.ReloadedByNamespace.With(prometheus.Labels{"success": "true", "namespace": config.Namespace}).Inc() - alert_on_reload, ok := os.LookupEnv("ALERT_ON_RELOAD") - if recorder != nil { - recorder.Event(i, v1.EventTypeNormal, "Reloaded", message) - } - if ok && alert_on_reload == "true" { - msg := fmt.Sprintf( - "Reloader detected changes in *%s* of type *%s* in namespace *%s*. Hence reloaded *%s* of type *%s* in namespace *%s*", - config.ResourceName, config.Type, config.Namespace, resourceName, upgradeFuncs.ResourceType, config.Namespace) - alert.SendWebhookAlert(msg) + if strategyResult.Result != constants.Updated && annotationValue != "" { + values := strings.Split(annotationValue, ",") + for _, value := range values { + value = strings.TrimSpace(value) + re := regexp.MustCompile("^" + value + "$") + if re.Match([]byte(config.ResourceName)) { + strategyResult = strategy(upgradeFuncs, resource, config, false) + if strategyResult.Result == constants.Updated { + break } } } } + + if strategyResult.Result != constants.Updated && searchAnnotationValue == "true" { + matchAnnotationValue := config.ResourceAnnotations[options.SearchMatchAnnotation] + if matchAnnotationValue == "true" { + strategyResult = strategy(upgradeFuncs, resource, config, true) + } + } + if strategyResult.Result == constants.Updated { + var err error + if upgradeFuncs.SupportsPatch && strategyResult.Patch != nil { + err = upgradeFuncs.PatchFunc(clients, config.Namespace, resource, strategyResult.Patch.Type, strategyResult.Patch.Bytes) + } else { + err = upgradeFuncs.UpdateFunc(clients, config.Namespace, resource) + } + + if err != nil { + message := fmt.Sprintf("Update for '%s' of type '%s' in namespace '%s' failed with error %v", resourceName, upgradeFuncs.ResourceType, config.Namespace, err) + logrus.Errorf("Update for '%s' of type '%s' in namespace '%s' failed with error %v", resourceName, upgradeFuncs.ResourceType, config.Namespace, err) + + collectors.Reloaded.With(prometheus.Labels{"success": "false"}).Inc() + collectors.ReloadedByNamespace.With(prometheus.Labels{"success": "false", "namespace": config.Namespace}).Inc() + if recorder != nil { + recorder.Event(resource, v1.EventTypeWarning, "ReloadFail", message) + } + return err + } else { + message := fmt.Sprintf("Changes detected in '%s' of type '%s' in namespace '%s'", config.ResourceName, config.Type, config.Namespace) + message += fmt.Sprintf(", Updated '%s' of type '%s' in namespace '%s'", resourceName, upgradeFuncs.ResourceType, config.Namespace) + + logrus.Infof("Changes detected in '%s' of type '%s' in namespace '%s'; updated '%s' of type '%s' in namespace '%s'", config.ResourceName, config.Type, config.Namespace, resourceName, upgradeFuncs.ResourceType, config.Namespace) + + collectors.Reloaded.With(prometheus.Labels{"success": "true"}).Inc() + collectors.ReloadedByNamespace.With(prometheus.Labels{"success": "true", "namespace": config.Namespace}).Inc() + alert_on_reload, ok := os.LookupEnv("ALERT_ON_RELOAD") + if recorder != nil { + recorder.Event(resource, v1.EventTypeNormal, "Reloaded", message) + } + if ok && alert_on_reload == "true" { + msg := fmt.Sprintf( + "Reloader detected changes in *%s* of type *%s* in namespace *%s*. Hence reloaded *%s* of type *%s* in namespace *%s*", + config.ResourceName, config.Type, config.Namespace, resourceName, upgradeFuncs.ResourceType, config.Namespace) + alert.SendWebhookAlert(msg) + } + } + } + return nil } @@ -439,42 +493,51 @@ func getContainerUsingResource(upgradeFuncs callbacks.RollingUpgradeFuncs, item return container } -type invokeStrategy func(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result +type Patch struct { + Type patchtypes.PatchType + Bytes []byte +} -func invokeReloadStrategy(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { +type InvokeStrategyResult struct { + Result constants.Result + Patch *Patch +} + +type invokeStrategy func(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult + +func invokeReloadStrategy(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { if options.ReloadStrategy == constants.AnnotationsReloadStrategy { return updatePodAnnotations(upgradeFuncs, item, config, autoReload) } - return updateContainerEnvVars(upgradeFuncs, item, config, autoReload) } -func updatePodAnnotations(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { +func updatePodAnnotations(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { container := getContainerUsingResource(upgradeFuncs, item, config, autoReload) if container == nil { - return constants.NoContainerFound + return InvokeStrategyResult{constants.NoContainerFound, nil} } // Generate reloaded annotations. Attaching this to the item's annotation will trigger a rollout // Note: the data on this struct is purely informational and is not used for future updates reloadSource := util.NewReloadSourceFromConfig(config, []string{container.Name}) - annotations, err := createReloadedAnnotations(&reloadSource) + annotations, patch, err := createReloadedAnnotations(&reloadSource, upgradeFuncs) if err != nil { logrus.Errorf("Failed to create reloaded annotations for %s! error = %v", config.ResourceName, err) - return constants.NotUpdated + return InvokeStrategyResult{constants.NotUpdated, nil} } // Copy the all annotations to the item's annotations pa := upgradeFuncs.PodAnnotationsFunc(item) if pa == nil { - return constants.NotUpdated + return InvokeStrategyResult{constants.NotUpdated, nil} } for k, v := range annotations { pa[k] = v } - return constants.Updated + return InvokeStrategyResult{constants.Updated, &Patch{Type: patchtypes.StrategicMergePatchType, Bytes: patch}} } func getReloaderAnnotationKey() string { @@ -484,9 +547,9 @@ func getReloaderAnnotationKey() string { ) } -func createReloadedAnnotations(target *util.ReloadSource) (map[string]string, error) { +func createReloadedAnnotations(target *util.ReloadSource, upgradeFuncs callbacks.RollingUpgradeFuncs) (map[string]string, []byte, error) { if target == nil { - return nil, errors.New("target is required") + return nil, nil, errors.New("target is required") } // Create a single "last-invokeReloadStrategy-from" annotation that stores metadata about the @@ -498,53 +561,76 @@ func createReloadedAnnotations(target *util.ReloadSource) (map[string]string, er lastReloadedResource, err := json.Marshal(target) if err != nil { - return nil, err + return nil, nil, err } annotations[lastReloadedResourceName] = string(lastReloadedResource) - return annotations, nil + + var patch []byte + if upgradeFuncs.SupportsPatch { + escapedValue, err := jsonEscape(annotations[lastReloadedResourceName]) + if err != nil { + return nil, nil, err + } + patch = fmt.Appendf(nil, upgradeFuncs.PatchTemplatesFunc().AnnotationTemplate, lastReloadedResourceName, escapedValue) + } + + return annotations, patch, nil } func getEnvVarName(resourceName string, typeName string) string { return constants.EnvVarPrefix + util.ConvertToEnvVarName(resourceName) + "_" + typeName } -func updateContainerEnvVars(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) constants.Result { - var result constants.Result +func updateContainerEnvVars(upgradeFuncs callbacks.RollingUpgradeFuncs, item runtime.Object, config util.Config, autoReload bool) InvokeStrategyResult { envVar := getEnvVarName(config.ResourceName, config.Type) container := getContainerUsingResource(upgradeFuncs, item, config, autoReload) if container == nil { - return constants.NoContainerFound + return InvokeStrategyResult{constants.NoContainerFound, nil} } //update if env var exists - result = updateEnvVar(upgradeFuncs.ContainersFunc(item), envVar, config.SHAValue) + updateResult := updateEnvVar(container, envVar, config.SHAValue) // if no existing env var exists lets create one - if result == constants.NoEnvVarFound { + if updateResult == constants.NoEnvVarFound { e := v1.EnvVar{ Name: envVar, Value: config.SHAValue, } container.Env = append(container.Env, e) - result = constants.Updated + updateResult = constants.Updated } - return result + + var patch []byte + if upgradeFuncs.SupportsPatch { + patch = fmt.Appendf(nil, upgradeFuncs.PatchTemplatesFunc().EnvVarTemplate, container.Name, envVar, config.SHAValue) + } + + return InvokeStrategyResult{updateResult, &Patch{Type: patchtypes.StrategicMergePatchType, Bytes: patch}} } -func updateEnvVar(containers []v1.Container, envVar string, shaData string) constants.Result { - for i := range containers { - envs := containers[i].Env - for j := range envs { - if envs[j].Name == envVar { - if envs[j].Value != shaData { - envs[j].Value = shaData - return constants.Updated - } - return constants.NotUpdated +func updateEnvVar(container *v1.Container, envVar string, shaData string) constants.Result { + envs := container.Env + for j := range envs { + if envs[j].Name == envVar { + if envs[j].Value != shaData { + envs[j].Value = shaData + return constants.Updated } + return constants.NotUpdated } } + return constants.NoEnvVarFound } + +func jsonEscape(toEscape string) (string, error) { + bytes, err := json.Marshal(toEscape) + if err != nil { + return "", err + } + escaped := string(bytes) + return escaped[1 : len(escaped)-1], nil +} diff --git a/internal/pkg/handler/upgrade_test.go b/internal/pkg/handler/upgrade_test.go index 2b71740..9817f21 100644 --- a/internal/pkg/handler/upgrade_test.go +++ b/internal/pkg/handler/upgrade_test.go @@ -17,9 +17,12 @@ import ( "github.com/stakater/Reloader/internal/pkg/testutil" "github.com/stakater/Reloader/internal/pkg/util" "github.com/stakater/Reloader/pkg/kube" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + patchtypes "k8s.io/apimachinery/pkg/types" testclient "k8s.io/client-go/kubernetes/fake" ) @@ -1413,6 +1416,22 @@ func testRollingUpgradeInvokeDeleteStrategyArs(t *testing.T, clients kube.Client } } +func testRollingUpgradeWithPatchAndInvokeDeleteStrategyArs(t *testing.T, clients kube.Clients, config util.Config, upgradeFuncs callbacks.RollingUpgradeFuncs, collectors metrics.Collectors, envVarPostfix string) { + err := PerformAction(clients, config, upgradeFuncs, collectors, nil, invokeDeleteStrategy) + upgradeFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + return nil + } + upgradeFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + if err != nil { + t.Errorf("Rolling upgrade failed for %s with %s", upgradeFuncs.ResourceType, envVarPostfix) + } +} + func TestRollingUpgradeForDeploymentWithConfigmapUsingArs(t *testing.T) { options.ReloadStrategy = constants.AnnotationsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -1444,6 +1463,47 @@ func TestRollingUpgradeForDeploymentWithConfigmapUsingArs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyArs(t, clients, config, deploymentFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForDeploymentWithPatchAndRetryUsingArs(t *testing.T) { + options.ReloadStrategy = constants.AnnotationsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, arsNamespace, arsConfigmapName, "www.stakater.com") + config := getConfigWithAnnotations(envVarPostfix, arsConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + deploymentFuncs := GetDeploymentRollingUpgradeFuncs() + + assert.True(t, deploymentFuncs.SupportsPatch) + assert.NotEmpty(t, deploymentFuncs.PatchTemplatesFunc().AnnotationTemplate) + + patchCalled := 0 + deploymentFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"metadata":{"annotations":{"reloader.stakater.com/last-reloaded-from":`) + assert.Contains(t, string(bytes), `\"hash\":\"3c9a892aeaedc759abc3df9884a37b8be5680382\"`) + return nil + } + + deploymentFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + err := PerformAction(clients, config, deploymentFuncs, collectors, nil, invokeReloadStrategy) + if err != nil { + t.Errorf("Rolling upgrade failed for Deployment with Configmap") + } + + assert.Equal(t, 2, patchCalled) + + deploymentFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyArs(t, clients, config, deploymentFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForDeploymentWithConfigmapWithoutReloadAnnotationAndWithoutAutoReloadAllNoTriggersUsingArs(t *testing.T) { options.ReloadStrategy = constants.AnnotationsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -1616,7 +1676,7 @@ func TestRollingUpgradeForDeploymentWithConfigmapViaSearchAnnotationNotMappedUsi t.Errorf("Failed to create deployment with search annotation.") } defer func() { - _ = clients.KubernetesClient.AppsV1().Deployments(arsNamespace).Delete(context.TODO(), deployment.Name, v1.DeleteOptions{}) + _ = clients.KubernetesClient.AppsV1().Deployments(arsNamespace).Delete(context.TODO(), deployment.Name, metav1.DeleteOptions{}) }() // defer clients.KubernetesClient.AppsV1().Deployments(namespace).Delete(deployment.Name, &v1.DeleteOptions{}) @@ -2102,6 +2162,7 @@ func TestRollingUpgradeForDeploymentWithExcludeConfigMapAnnotationUsingArs(t *te t.Errorf("Deployment which had to be excluded was updated") } } + func TestRollingUpgradeForDeploymentWithConfigMapAutoAnnotationUsingArs(t *testing.T) { options.ReloadStrategy = constants.AnnotationsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -2166,6 +2227,48 @@ func TestRollingUpgradeForDaemonSetWithConfigmapUsingArs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyArs(t, clients, config, daemonSetFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForDaemonSetWithPatchAndRetryUsingArs(t *testing.T) { + options.ReloadStrategy = constants.AnnotationsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, arsNamespace, arsConfigmapName, "www.facebook.com") + config := getConfigWithAnnotations(envVarPostfix, arsConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + daemonSetFuncs := GetDaemonSetRollingUpgradeFuncs() + + assert.True(t, daemonSetFuncs.SupportsPatch) + assert.NotEmpty(t, daemonSetFuncs.PatchTemplatesFunc().AnnotationTemplate) + + patchCalled := 0 + daemonSetFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"metadata":{"annotations":{"reloader.stakater.com/last-reloaded-from":`) + assert.Contains(t, string(bytes), `\"hash\":\"314a2269170750a974d79f02b5b9ee517de7f280\"`) + return nil + } + + daemonSetFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + + err := PerformAction(clients, config, daemonSetFuncs, collectors, nil, invokeReloadStrategy) + if err != nil { + t.Errorf("Rolling upgrade failed for DaemonSet with configmap") + } + + assert.Equal(t, 2, patchCalled) + + daemonSetFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyArs(t, clients, config, daemonSetFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForDaemonSetWithConfigmapInProjectedVolumeUsingArs(t *testing.T) { options.ReloadStrategy = constants.AnnotationsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -2326,6 +2429,48 @@ func TestRollingUpgradeForStatefulSetWithConfigmapUsingArs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyArs(t, clients, config, statefulSetFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForStatefulSetWithPatchAndRetryUsingArs(t *testing.T) { + options.ReloadStrategy = constants.AnnotationsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, arsNamespace, arsConfigmapName, "www.twitter.com") + config := getConfigWithAnnotations(envVarPostfix, arsConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + statefulSetFuncs := GetStatefulSetRollingUpgradeFuncs() + + assert.True(t, statefulSetFuncs.SupportsPatch) + assert.NotEmpty(t, statefulSetFuncs.PatchTemplatesFunc().AnnotationTemplate) + + patchCalled := 0 + statefulSetFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"metadata":{"annotations":{"reloader.stakater.com/last-reloaded-from":`) + assert.Contains(t, string(bytes), `\"hash\":\"f821414d40d8815fb330763f74a4ff7ab651d4fa\"`) + return nil + } + + statefulSetFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + + err := PerformAction(clients, config, statefulSetFuncs, collectors, nil, invokeReloadStrategy) + if err != nil { + t.Errorf("Rolling upgrade failed for StatefulSet with configmap") + } + + assert.Equal(t, 2, patchCalled) + + statefulSetFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyArs(t, clients, config, statefulSetFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForStatefulSetWithConfigmapInProjectedVolumeUsingArs(t *testing.T) { options.ReloadStrategy = constants.AnnotationsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -2488,6 +2633,9 @@ func TestFailedRollingUpgradeUsingArs(t *testing.T) { deploymentFuncs.UpdateFunc = func(_ kube.Clients, _ string, _ runtime.Object) error { return fmt.Errorf("error") } + deploymentFuncs.PatchFunc = func(kube.Clients, string, runtime.Object, patchtypes.PatchType, []byte) error { + return fmt.Errorf("error") + } collectors := getCollectors() _ = PerformAction(clients, config, deploymentFuncs, collectors, nil, invokeReloadStrategy) @@ -2518,6 +2666,24 @@ func testRollingUpgradeInvokeDeleteStrategyErs(t *testing.T, clients kube.Client } } +func testRollingUpgradeWithPatchAndInvokeDeleteStrategyErs(t *testing.T, clients kube.Clients, config util.Config, upgradeFuncs callbacks.RollingUpgradeFuncs, collectors metrics.Collectors, envVarPostfix string) { + assert.NotEmpty(t, upgradeFuncs.PatchTemplatesFunc().DeleteEnvVarTemplate) + + err := PerformAction(clients, config, upgradeFuncs, collectors, nil, invokeDeleteStrategy) + upgradeFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + assert.Equal(t, patchtypes.JSONPatchType, patchType) + assert.NotEmpty(t, bytes) + return nil + } + upgradeFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + if err != nil { + t.Errorf("Rolling upgrade failed for %s with %s", upgradeFuncs.ResourceType, envVarPostfix) + } +} + func TestRollingUpgradeForDeploymentWithConfigmapUsingErs(t *testing.T) { options.ReloadStrategy = constants.EnvVarsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -2550,6 +2716,48 @@ func TestRollingUpgradeForDeploymentWithConfigmapUsingErs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyErs(t, clients, config, deploymentFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForDeploymentWithPatchAndRetryUsingErs(t *testing.T) { + options.ReloadStrategy = constants.EnvVarsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, ersNamespace, ersConfigmapName, "www.stakater.com") + config := getConfigWithAnnotations(envVarPostfix, ersConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + deploymentFuncs := GetDeploymentRollingUpgradeFuncs() + + assert.True(t, deploymentFuncs.SupportsPatch) + assert.NotEmpty(t, deploymentFuncs.PatchTemplatesFunc().EnvVarTemplate) + + patchCalled := 0 + deploymentFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"spec":{"containers":[{"name":`) + assert.Contains(t, string(bytes), `"value":"3c9a892aeaedc759abc3df9884a37b8be5680382"`) + return nil + } + + deploymentFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + + err := PerformAction(clients, config, deploymentFuncs, collectors, nil, invokeReloadStrategy) + if err != nil { + t.Errorf("Rolling upgrade failed for %s with %s", deploymentFuncs.ResourceType, envVarPostfix) + } + + assert.Equal(t, 2, patchCalled) + + deploymentFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyErs(t, clients, config, deploymentFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForDeploymentWithConfigmapInProjectedVolumeUsingErs(t *testing.T) { options.ReloadStrategy = constants.EnvVarsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -2658,7 +2866,7 @@ func TestRollingUpgradeForDeploymentWithConfigmapViaSearchAnnotationNotMappedUsi t.Errorf("Failed to create deployment with search annotation.") } defer func() { - _ = clients.KubernetesClient.AppsV1().Deployments(ersNamespace).Delete(context.TODO(), deployment.Name, v1.DeleteOptions{}) + _ = clients.KubernetesClient.AppsV1().Deployments(ersNamespace).Delete(context.TODO(), deployment.Name, metav1.DeleteOptions{}) }() // defer clients.KubernetesClient.AppsV1().Deployments(namespace).Delete(deployment.Name, &v1.DeleteOptions{}) @@ -3212,6 +3420,49 @@ func TestRollingUpgradeForDaemonSetWithConfigmapUsingErs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyErs(t, clients, config, daemonSetFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForDaemonSetWithPatchAndRetryUsingErs(t *testing.T) { + options.ReloadStrategy = constants.EnvVarsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, ersNamespace, ersConfigmapName, "www.facebook.com") + config := getConfigWithAnnotations(envVarPostfix, ersConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + daemonSetFuncs := GetDaemonSetRollingUpgradeFuncs() + + assert.True(t, daemonSetFuncs.SupportsPatch) + assert.NotEmpty(t, daemonSetFuncs.PatchTemplatesFunc().EnvVarTemplate) + + patchCalled := 0 + daemonSetFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"spec":{"containers":[{"name":`) + assert.Contains(t, string(bytes), `"value":"314a2269170750a974d79f02b5b9ee517de7f280"`) + return nil + } + + daemonSetFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + + err := PerformAction(clients, config, daemonSetFuncs, collectors, nil, invokeReloadStrategy) + time.Sleep(5 * time.Second) + if err != nil { + t.Errorf("Rolling upgrade failed for DaemonSet with configmap") + } + + assert.Equal(t, 2, patchCalled) + + daemonSetFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyErs(t, clients, config, daemonSetFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForDaemonSetWithConfigmapInProjectedVolumeUsingErs(t *testing.T) { options.ReloadStrategy = constants.EnvVarsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -3372,6 +3623,49 @@ func TestRollingUpgradeForStatefulSetWithConfigmapUsingErs(t *testing.T) { testRollingUpgradeInvokeDeleteStrategyErs(t, clients, config, statefulSetFuncs, collectors, envVarPostfix) } +func TestRollingUpgradeForStatefulSetWithPatchAndRetryUsingErs(t *testing.T) { + options.ReloadStrategy = constants.EnvVarsReloadStrategy + envVarPostfix := constants.ConfigmapEnvVarPostfix + + shaData := testutil.ConvertResourceToSHA(testutil.ConfigmapResourceType, ersNamespace, ersConfigmapName, "www.twitter.com") + config := getConfigWithAnnotations(envVarPostfix, ersConfigmapName, shaData, options.ConfigmapUpdateOnChangeAnnotation, options.ConfigmapReloaderAutoAnnotation) + statefulSetFuncs := GetStatefulSetRollingUpgradeFuncs() + + assert.True(t, statefulSetFuncs.SupportsPatch) + assert.NotEmpty(t, statefulSetFuncs.PatchTemplatesFunc().EnvVarTemplate) + + patchCalled := 0 + statefulSetFuncs.PatchFunc = func(client kube.Clients, namespace string, resource runtime.Object, patchType patchtypes.PatchType, bytes []byte) error { + patchCalled++ + if patchCalled < 2 { + return &errors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonConflict}} // simulate conflict + } + assert.Equal(t, patchtypes.StrategicMergePatchType, patchType) + assert.NotEmpty(t, bytes) + assert.Contains(t, string(bytes), `{"spec":{"template":{"spec":{"containers":[{"name":`) + assert.Contains(t, string(bytes), `"value":"f821414d40d8815fb330763f74a4ff7ab651d4fa"`) + return nil + } + + statefulSetFuncs.UpdateFunc = func(kube.Clients, string, runtime.Object) error { + t.Errorf("Update should not be called") + return nil + } + + collectors := getCollectors() + + err := PerformAction(clients, config, statefulSetFuncs, collectors, nil, invokeReloadStrategy) + time.Sleep(5 * time.Second) + if err != nil { + t.Errorf("Rolling upgrade failed for StatefulSet with configmap") + } + + assert.Equal(t, 2, patchCalled) + + statefulSetFuncs = GetDeploymentRollingUpgradeFuncs() + testRollingUpgradeWithPatchAndInvokeDeleteStrategyErs(t, clients, config, statefulSetFuncs, collectors, envVarPostfix) +} + func TestRollingUpgradeForStatefulSetWithConfigmapInProjectedVolumeUsingErs(t *testing.T) { options.ReloadStrategy = constants.EnvVarsReloadStrategy envVarPostfix := constants.ConfigmapEnvVarPostfix @@ -3536,6 +3830,9 @@ func TestFailedRollingUpgradeUsingErs(t *testing.T) { deploymentFuncs.UpdateFunc = func(_ kube.Clients, _ string, _ runtime.Object) error { return fmt.Errorf("error") } + deploymentFuncs.PatchFunc = func(kube.Clients, string, runtime.Object, patchtypes.PatchType, []byte) error { + return fmt.Errorf("error") + } collectors := getCollectors() _ = PerformAction(clients, config, deploymentFuncs, collectors, nil, invokeReloadStrategy) diff --git a/internal/pkg/testutil/kube.go b/internal/pkg/testutil/kube.go index 1f779ab..f2d3bb4 100644 --- a/internal/pkg/testutil/kube.go +++ b/internal/pkg/testutil/kube.go @@ -968,6 +968,22 @@ func DeleteStatefulSet(client kubernetes.Interface, namespace string, statefulse return statefulsetError } +// DeleteCronJob deletes a cronJob in given namespace and returns the error if any +func DeleteCronJob(client kubernetes.Interface, namespace string, cronJobName string) error { + logrus.Infof("Deleting CronJob %s", cronJobName) + cronJobError := client.BatchV1().CronJobs(namespace).Delete(context.TODO(), cronJobName, metav1.DeleteOptions{}) + time.Sleep(3 * time.Second) + return cronJobError +} + +// Deleteob deletes a job in given namespace and returns the error if any +func DeleteJob(client kubernetes.Interface, namespace string, jobName string) error { + logrus.Infof("Deleting Job %s", jobName) + jobError := client.BatchV1().Jobs(namespace).Delete(context.TODO(), jobName, metav1.DeleteOptions{}) + time.Sleep(3 * time.Second) + return jobError +} + // UpdateConfigMap updates a configmap in given namespace and returns the error if any func UpdateConfigMap(configmapClient core_v1.ConfigMapInterface, namespace string, configmapName string, label string, data string) error { logrus.Infof("Updating configmap %q.\n", configmapName)