feat: add GCP zone outage rollback support (#1200)

Add rollback functionality for GCP zone outage scenarios following the
established rollback pattern (Service Hijacking, PVC, Syn Flood).

- Add @set_rollback_context_decorator to run()
- Register the rollback callable, carrying the node list and settings as base64-encoded JSON, before stopping nodes
- Add rollback_gcp_zone_outage() static method with per-node error handling
- Fix missing poll_interval argument in starmap calls
- Add unit tests for rollback and run methods

Closes #915

Signed-off-by: YASHASVIYADAV30 <yashasviydv30@gmail.com>
Co-authored-by: Paige Patton <64206430+paigerube14@users.noreply.github.com>
This commit is contained in:
Yashasvi Yadav
2026-03-27 00:12:45 +05:30
committed by GitHub
parent ec241d35d6
commit 62f500fb2e
2 changed files with 314 additions and 19 deletions

View File

@@ -1,3 +1,5 @@
import base64
import json
import logging
import time
@@ -13,11 +15,15 @@ from krkn_lib.telemetry.ocp import KrknTelemetryOpenshift
from krkn.scenario_plugins.abstract_scenario_plugin import AbstractScenarioPlugin
from krkn_lib.utils import get_yaml_item_value
from krkn.rollback.config import RollbackContent
from krkn.rollback.handler import set_rollback_context_decorator
from krkn.scenario_plugins.node_actions.aws_node_scenarios import AWS
from krkn.scenario_plugins.node_actions.gcp_node_scenarios import gcp_node_scenarios
class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
@set_rollback_context_decorator
def run(
self,
run_uuid: str,
@@ -40,7 +46,9 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
if cloud_type.lower() == "gcp":
affected_nodes_status = AffectedNodeStatus()
self.cloud_object = gcp_node_scenarios(kubecli, kube_check, affected_nodes_status)
self.node_based_zone(scenario_config, kubecli)
result = self.node_based_zone(scenario_config, kubecli)
if result != 0:
return result
affected_nodes_status = self.cloud_object.affected_nodes_status
scenario_telemetry.affected_nodes.extend(affected_nodes_status.affected_nodes)
else:
@@ -57,22 +65,37 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
return 1
else:
return 0
def node_based_zone(self, scenario_config: dict[str, any], kubecli: KrknKubernetes ):
def node_based_zone(self, scenario_config: dict[str, any], kubecli: KrknKubernetes):
zone = scenario_config["zone"]
duration = get_yaml_item_value(scenario_config, "duration", 60)
timeout = get_yaml_item_value(scenario_config, "timeout", 180)
kube_check = get_yaml_item_value(scenario_config, "kube_check", True)
label_selector = f"topology.kubernetes.io/zone={zone}"
try:
try:
# get list of nodes in zone/region
nodes = kubecli.list_killable_nodes(label_selector)
# stop nodes in parallel
pool = ThreadPool(processes=len(nodes))
pool.starmap(
self.cloud_object.node_stop_scenario,zip(repeat(1), nodes, repeat(timeout))
# set rollback callable before stopping nodes
rollback_data = {
"nodes": nodes,
"timeout": timeout,
"kube_check": kube_check,
}
encoded = base64.b64encode(
json.dumps(rollback_data).encode("utf-8")
).decode("utf-8")
self.rollback_handler.set_rollback_callable(
self.rollback_gcp_zone_outage,
RollbackContent(resource_identifier=encoded),
)
# stop nodes in parallel
pool = ThreadPool(processes=len(nodes))
pool.starmap(
self.cloud_object.node_stop_scenario,
zip(repeat(1), nodes, repeat(timeout), repeat(None)),
)
pool.close()
logging.info(
@@ -80,10 +103,11 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
)
time.sleep(duration)
# start nodes in parallel
# start nodes in parallel
pool = ThreadPool(processes=len(nodes))
pool.starmap(
self.cloud_object.node_start_scenario,zip(repeat(1), nodes, repeat(timeout))
self.cloud_object.node_start_scenario,
zip(repeat(1), nodes, repeat(timeout), repeat(None)),
)
pool.close()
except Exception as e:
@@ -94,6 +118,58 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
else:
return 0
@staticmethod
def rollback_gcp_zone_outage(
    rollback_content: RollbackContent,
    lib_telemetry: KrknTelemetryOpenshift,
):
    """Restart the nodes that a GCP zone outage scenario stopped.

    The node list and settings are read from
    ``rollback_content.resource_identifier``, which is expected to hold a
    base64-encoded JSON object with the keys ``nodes``, ``timeout`` and
    ``kube_check``. Nodes are started one at a time; a failure to start one
    node is logged and does not stop the remaining nodes from being
    attempted.

    :param rollback_content: Rollback content containing encoded node
        list and config.
    :param lib_telemetry: Instance of KrknTelemetryOpenshift for
        Kubernetes operations.
    :raises Exception: Any error raised while decoding the payload or
        building the cloud client is logged and re-raised. Per-node start
        failures are logged but not raised.
    """
    try:
        # NOTE(review): these imports are function-local although json and
        # base64 are also imported at module level; presumably so the
        # rollback can run standalone - confirm before hoisting them.
        import json
        import base64
        from krkn_lib.models.k8s import AffectedNodeStatus
        from krkn.scenario_plugins.node_actions.gcp_node_scenarios import (
            gcp_node_scenarios,
        )

        decoded = base64.b64decode(
            rollback_content.resource_identifier.encode("utf-8")
        ).decode("utf-8")
        rollback_data = json.loads(decoded)
        nodes = rollback_data["nodes"]
        timeout = rollback_data["timeout"]
        kube_check = rollback_data["kube_check"]

        kubecli = lib_telemetry.get_lib_kubernetes()
        affected_nodes_status = AffectedNodeStatus()
        cloud_object = gcp_node_scenarios(
            kubecli, kube_check, affected_nodes_status
        )

        logging.info(
            "Rolling back GCP zone outage: starting %d stopped nodes",
            len(nodes),
        )

        # Remember which nodes could not be started so the final log line
        # does not claim success for a partial rollback.
        failed_nodes = []
        for node in nodes:
            try:
                cloud_object.node_start_scenario(1, node, timeout, None)
            except Exception as node_error:
                failed_nodes.append(node)
                logging.error(
                    "Failed to start node %s during rollback: %s",
                    node,
                    node_error,
                )

        if failed_nodes:
            logging.error(
                "GCP zone outage rollback finished with %d of %d nodes "
                "not started: %s",
                len(failed_nodes),
                len(nodes),
                failed_nodes,
            )
        else:
            logging.info("GCP zone outage rollback completed.")
    except Exception as e:
        logging.error("Failed to rollback GCP zone outage: %s", e)
        raise
def network_based_zone(self, scenario_config: dict[str, any]):
vpc_id = scenario_config["vpc_id"]
@@ -118,12 +194,12 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
"Network association ids associated with "
"the subnet %s: %s" % (subnet_id, network_association_ids)
)
# Use provided default ACL if available, otherwise create a new one
if default_acl_id:
acl_id = default_acl_id
logging.info(
"Using provided default ACL ID %s - this ACL will not be deleted after the scenario",
"Using provided default ACL ID %s - this ACL will not be deleted after the scenario",
default_acl_id
)
# Don't add to acl_ids_created since we don't want to delete user-provided ACLs at cleanup
@@ -160,6 +236,5 @@ class ZoneOutageScenarioPlugin(AbstractScenarioPlugin):
for acl_id in acl_ids_created:
self.cloud_object.delete_network_acl(acl_id)
def get_scenario_types(self) -> list[str]:
    """Return the scenario type names this plugin handles."""
    scenario_types = ["zone_outages_scenarios"]
    return scenario_types

View File

@@ -4,18 +4,26 @@
Test suite for ZoneOutageScenarioPlugin class
Usage:
python -m coverage run -a -m unittest tests/test_zone_outage_scenario_plugin.py -v
python -m coverage run -a -m unittest \
tests/test_zone_outage_scenario_plugin.py -v
Assisted By: Claude Code
"""
import base64
import json
import tempfile
import unittest
from unittest.mock import MagicMock
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
from krkn_lib.k8s import KrknKubernetes
from krkn_lib.telemetry.ocp import KrknTelemetryOpenshift
import yaml
from krkn.scenario_plugins.zone_outage.zone_outage_scenario_plugin import ZoneOutageScenarioPlugin
from krkn.rollback.config import RollbackContent
from krkn.scenario_plugins.zone_outage.zone_outage_scenario_plugin import (
ZoneOutageScenarioPlugin,
)
class TestZoneOutageScenarioPlugin(unittest.TestCase):
@@ -36,5 +44,217 @@ class TestZoneOutageScenarioPlugin(unittest.TestCase):
self.assertEqual(len(result), 1)
class TestRollbackGcpZoneOutage(unittest.TestCase):
    """Tests for the GCP zone outage rollback functionality"""

    @staticmethod
    def _build_rollback_content(nodes, timeout=180, kube_check=True):
        """Encode a rollback payload as base64 JSON, mirroring the format
        that rollback_gcp_zone_outage decodes."""
        payload = {
            "nodes": nodes,
            "timeout": timeout,
            "kube_check": kube_check,
        }
        encoded = base64.b64encode(
            json.dumps(payload).encode("utf-8")
        ).decode("utf-8")
        return RollbackContent(resource_identifier=encoded)

    @staticmethod
    def _build_telemetry_mock():
        """Return (lib_telemetry, kubecli) mocks wired together."""
        mock_lib_telemetry = MagicMock()
        mock_kubecli = MagicMock()
        mock_lib_telemetry.get_lib_kubernetes.return_value = mock_kubecli
        return mock_lib_telemetry, mock_kubecli

    @patch(
        "krkn.scenario_plugins.node_actions."
        "gcp_node_scenarios.gcp_node_scenarios"
    )
    def test_rollback_gcp_zone_outage_success(self, mock_gcp_class):
        """
        Test successful rollback starts all stopped nodes
        """
        node_names = ["node-1", "node-2", "node-3"]
        rollback_content = self._build_rollback_content(node_names)
        mock_lib_telemetry, mock_kubecli = self._build_telemetry_mock()

        mock_cloud_instance = MagicMock()
        mock_gcp_class.return_value = mock_cloud_instance

        ZoneOutageScenarioPlugin.rollback_gcp_zone_outage(
            rollback_content, mock_lib_telemetry
        )

        # The cloud client must be built from the telemetry's kubecli and
        # the kube_check flag carried in the payload.
        mock_gcp_class.assert_called_once()
        positional_args = mock_gcp_class.call_args[0]
        self.assertEqual(positional_args[:2], (mock_kubecli, True))

        self.assertEqual(
            mock_cloud_instance.node_start_scenario.call_count,
            len(node_names),
        )
        for node_name in node_names:
            mock_cloud_instance.node_start_scenario.assert_any_call(
                1, node_name, 180, None
            )

    @patch(
        "krkn.scenario_plugins.node_actions."
        "gcp_node_scenarios.gcp_node_scenarios"
    )
    def test_rollback_gcp_zone_outage_partial_failure(self, mock_gcp_class):
        """
        Test rollback continues when one node fails to start
        """
        rollback_content = self._build_rollback_content(["node-1", "node-2"])
        mock_lib_telemetry, _ = self._build_telemetry_mock()

        mock_cloud_instance = MagicMock()
        mock_gcp_class.return_value = mock_cloud_instance
        # First start attempt fails, second succeeds.
        mock_cloud_instance.node_start_scenario.side_effect = [
            Exception("GCP API error"),
            None,
        ]

        # Must not raise: per-node failures are logged, not propagated.
        ZoneOutageScenarioPlugin.rollback_gcp_zone_outage(
            rollback_content, mock_lib_telemetry
        )

        self.assertEqual(
            mock_cloud_instance.node_start_scenario.call_count, 2
        )
        # The node after the failing one must still have been attempted.
        mock_cloud_instance.node_start_scenario.assert_any_call(
            1, "node-2", 180, None
        )

    def test_rollback_gcp_zone_outage_invalid_data(self):
        """
        Test rollback raises exception for invalid base64 data
        """
        rollback_content = RollbackContent(
            resource_identifier="invalid_base64_data",
        )
        mock_lib_telemetry = MagicMock()

        # base64 decoding errors are binascii.Error, a ValueError subclass.
        # Asserting on ValueError rather than Exception keeps the test from
        # passing on unrelated failures such as import or attribute errors.
        with self.assertRaises(ValueError):
            ZoneOutageScenarioPlugin.rollback_gcp_zone_outage(
                rollback_content, mock_lib_telemetry
            )
class TestZoneOutageRun(unittest.TestCase):
    """Tests for the run method of ZoneOutageScenarioPlugin"""

    def setUp(self):
        # Each test gets its own scratch directory for the scenario file.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.tmp_path = Path(self.temp_dir.name)

    def tearDown(self):
        self.temp_dir.cleanup()

    def _create_scenario_file(self, config=None):
        """Write a scenario YAML file (defaults overlaid with ``config``)
        into the temp directory and return its path as a string."""
        zone_outage_settings = {
            "cloud_type": "gcp",
            "zone": "us-central1-a",
            "duration": 1,
            "timeout": 10,
            "kube_check": True,
        }
        if config:
            zone_outage_settings.update(config)

        target = self.tmp_path / "test_scenario.yaml"
        with open(target, "w") as handle:
            yaml.dump({"zone_outage": zone_outage_settings}, handle)
        return str(target)

    def _create_mocks(self):
        """Build the (lib_telemetry, lib_kubernetes, scenario_telemetry)
        mocks, with lib_telemetry returning lib_kubernetes."""
        telemetry = MagicMock()
        kubernetes = MagicMock()
        telemetry.get_lib_kubernetes.return_value = kubernetes
        scenario_telemetry = MagicMock()
        return telemetry, kubernetes, scenario_telemetry

    def _run_plugin(self, scenario_file, lib_telemetry, scenario_telemetry):
        """Run a fresh plugin instance against ``scenario_file`` and
        return its result code."""
        plugin = ZoneOutageScenarioPlugin()
        return plugin.run(
            run_uuid=str(uuid.uuid4()),
            scenario=scenario_file,
            lib_telemetry=lib_telemetry,
            scenario_telemetry=scenario_telemetry,
        )

    @patch("time.sleep")
    @patch(
        "krkn.scenario_plugins.zone_outage."
        "zone_outage_scenario_plugin.gcp_node_scenarios"
    )
    def test_run_gcp_success(self, mock_gcp_class, mock_sleep):
        """Test successful GCP zone outage scenario execution"""
        scenario_file = self._create_scenario_file()
        telemetry, kubernetes, scenario_telemetry = self._create_mocks()
        kubernetes.list_killable_nodes.return_value = ["node-1"]

        mock_cloud = MagicMock()
        mock_gcp_class.return_value = mock_cloud

        result = self._run_plugin(
            scenario_file, telemetry, scenario_telemetry
        )

        self.assertEqual(result, 0)
        kubernetes.list_killable_nodes.assert_called_once()
        mock_cloud.node_stop_scenario.assert_called()
        mock_cloud.node_start_scenario.assert_called()

    def test_run_unsupported_cloud_type(self):
        """Test run returns 1 for unsupported cloud type"""
        scenario_file = self._create_scenario_file(
            {"cloud_type": "unsupported"}
        )
        telemetry, _kubernetes, scenario_telemetry = self._create_mocks()

        result = self._run_plugin(
            scenario_file, telemetry, scenario_telemetry
        )

        self.assertEqual(result, 1)

    def test_run_gcp_exception(self):
        """Test run handles exceptions gracefully"""
        scenario_file = self._create_scenario_file()
        telemetry, _kubernetes, scenario_telemetry = self._create_mocks()
        # Fail as soon as the plugin asks for the Kubernetes client.
        telemetry.get_lib_kubernetes.side_effect = Exception(
            "Connection error"
        )

        result = self._run_plugin(
            scenario_file, telemetry, scenario_telemetry
        )

        self.assertEqual(result, 1)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()