diff --git a/kube-hunter.py b/kube-hunter.py index ee3bcd0..e8cedf4 100755 --- a/kube-hunter.py +++ b/kube-hunter.py @@ -82,7 +82,7 @@ def list_hunters(): print("\nPassive Hunters:\n----------------") for i, (hunter, docs) in enumerate(handler.passive_hunters.items()): name, docs = parse_docs(hunter, docs) - print("* {}\n {}\n".format( name, docs)) + print("* {}\n {}\n".format(name, docs)) if config.active: print("\n\nActive Hunters:\n---------------") @@ -91,11 +91,13 @@ def list_hunters(): print("* {}\n {}\n".format( name, docs)) +global hunt_started_lock +hunt_started_lock = threading.Lock() hunt_started = False def main(): - global hunt_started + global hunt_started scan_options = [ config.pod, config.cidr, @@ -109,7 +111,10 @@ def main(): if not any(scan_options): if not interactive_set_config(): return + + hunt_started_lock.acquire() hunt_started = True + hunt_started_lock.release() handler.publish_event(HuntStarted()) handler.publish_event(HostScanEvent()) @@ -121,11 +126,16 @@ def main(): except EOFError: logging.error("\033[0;31mPlease run again with -it\033[0m") finally: + hunt_started_lock.acquire() if hunt_started: + hunt_started_lock.release() handler.publish_event(HuntFinished()) handler.join() handler.free() logging.debug("Cleaned Queue") + else: + hunt_started_lock.release() + if __name__ == '__main__': diff --git a/src/core/events/handler.py b/src/core/events/handler.py index 916aeff..dd6e6bc 100644 --- a/src/core/events/handler.py +++ b/src/core/events/handler.py @@ -12,7 +12,8 @@ from ..types import ActiveHunter, Hunter from ...core.events.types import HuntFinished import threading -working_count = 0 +global queue_lock +queue_lock = Lock() # Inherits Queue object, handles events asynchronously class EventQueue(Queue, object): @@ -34,12 +35,12 @@ class EventQueue(Queue, object): t.daemon = True t.start() - # decorator wrapping for easy subscription def subscribe(self, event, hook=None, predicate=None): def wrapper(hook): self.subscribe_event(event, hook=hook, predicate=predicate) return hook + return wrapper # getting uninstantiated event object @@ -72,7 +73,9 @@ class EventQueue(Queue, object): # executes callbacks on dedicated thread as a daemon def worker(self): while self.running: + queue_lock.acquire() hook = self.get() + queue_lock.release() try: hook.execute() except Exception as ex: diff --git a/src/core/events/types/common.py b/src/core/events/types/common.py index 62590e9..ffe60a4 100644 --- a/src/core/events/types/common.py +++ b/src/core/events/types/common.py @@ -65,6 +65,9 @@ class Vulnerability(object): def explain(self): return self.__doc__ + +global event_id_count_lock +event_id_count_lock = threading.Lock() event_id_count = 0 """ Discovery/Hunting Events """ @@ -75,8 +78,10 @@ class NewHostEvent(Event): global event_id_count self.host = host self.cloud = cloud + event_id_count_lock.acquire() self.event_id = event_id_count event_id_count += 1 + event_id_count_lock.release() def __str__(self): return str(self.host) diff --git a/src/modules/discovery/hosts.py b/src/modules/discovery/hosts.py index f4d5e1d..9a87766 100644 --- a/src/modules/discovery/hosts.py +++ b/src/modules/discovery/hosts.py @@ -53,7 +53,7 @@ class HostDiscovery(Hunter): def __init__(self, event): self.event = event - def execute(self): + def execute(self): if config.cidr: try: ip, sn = config.cidr.split('/') diff --git a/src/modules/report/collector.py b/src/modules/report/collector.py index db6b77f..b23e454 100644 --- a/src/modules/report/collector.py +++ b/src/modules/report/collector.py @@ -5,7 +5,13 @@ from src.core.events import handler from src.core.events.types import Event, Service, Vulnerability, HuntFinished, HuntStarted import threading + +global services_lock +services_lock = threading.Lock() services = list() + +global vulnerabilities_lock +vulnerabilities_lock = threading.Lock() vulnerabilities = list() @@ -38,10 +44,13 @@ class Collector(object): def execute(self): """function is called only when collecting data""" - global services, vulnerabilities + global services + global vulnerabilities bases = self.event.__class__.__mro__ if Service in bases: + services_lock.acquire() services.append(self.event) + services_lock.release() import datetime logging.info("|\n| {name}:\n| type: open service\n| service: {name}\n|_ host: {host}:{port}".format( host=self.event.host, @@ -51,7 +60,9 @@ class Collector(object): )) elif Vulnerability in bases: + vulnerabilities_lock.acquire() vulnerabilities.append(self.event) + vulnerabilities_lock.release() logging.info( "|\n| {name}:\n| type: vulnerability\n| host: {host}:{port}\n| description: \n{desc}".format( name=self.event.get_name(), diff --git a/src/modules/report/plain.py b/src/modules/report/plain.py index 37e1838..029213d 100644 --- a/src/modules/report/plain.py +++ b/src/modules/report/plain.py @@ -3,8 +3,7 @@ from __future__ import print_function from prettytable import ALL, PrettyTable from __main__ import config -from collector import services, vulnerabilities -import threading +from collector import services, vulnerabilities, services_lock, vulnerabilities_lock EVIDENCE_PREVIEW = 40 MAX_TABLE_WIDTH = 20 @@ -15,11 +14,20 @@ class PlainReporter(object): def get_report(self): """generates report tables""" output = "" - if len(services): + + vulnerabilities_lock.acquire() + vulnerabilities_len = len(services) + vulnerabilities_lock.release() + + services_lock.acquire() + services_len = len(vulnerabilities) + services_lock.release() + + if services_len: output += self.nodes_table() if not config.mapping: output += self.services_table() - if len(vulnerabilities): + if vulnerabilities_len: output += self.vulns_table() else: output += "\nNo vulnerabilities were found" @@ -38,11 +46,14 @@ class PlainReporter(object): nodes_table.header_style = "upper" # TODO: replace with sets id_memory = list() + services_lock.acquire() for service in services: if service.event_id not in id_memory: nodes_table.add_row(["Node/Master", service.host]) id_memory.append(service.event_id) - return "\nNodes\n{}\n".format(nodes_table) + nodes_ret = "\nNodes\n{}\n".format(nodes_table) + services_lock.release() + return nodes_ret def services_table(self): services_table = PrettyTable(["Service", "Location", "Description"], hrules=ALL) @@ -52,9 +63,12 @@ class PlainReporter(object): services_table.sortby = "Service" services_table.reversesort = True services_table.header_style = "upper" + services_lock.acquire() for service in services: services_table.add_row([service.get_name(), "{}:{}{}".format(service.host, service.port, service.get_path()), service.explain()]) - return "\nDetected Services\n{}\n".format(services_table) + detected_services_ret = "\nDetected Services\n{}\n".format(services_table) + services_lock.release() + return detected_services_ret def vulns_table(self): column_names = ["Location", "Category", "Vulnerability", "Description", "Evidence"] @@ -65,9 +79,12 @@ class PlainReporter(object): vuln_table.reversesort = True vuln_table.padding_width = 1 vuln_table.header_style = "upper" + + vulnerabilities_lock.acquire() for vuln in vulnerabilities: row = ["{}:{}".format(vuln.host, vuln.port) if vuln.host else "", vuln.category.name, vuln.get_name(), vuln.explain()] evidence = str(vuln.evidence)[:EVIDENCE_PREVIEW] + "..." if len(str(vuln.evidence)) > EVIDENCE_PREVIEW else str(vuln.evidence) row.append(evidence) vuln_table.add_row(row) + vulnerabilities_lock.release() return "\nVulnerabilities\n{}\n".format(vuln_table) diff --git a/src/modules/report/yaml.py b/src/modules/report/yaml.py index 26a9e83..1ffcf58 100644 --- a/src/modules/report/yaml.py +++ b/src/modules/report/yaml.py @@ -2,8 +2,7 @@ import StringIO from ruamel.yaml import YAML -from collector import services, vulnerabilities -import threading +from collector import services, vulnerabilities, services_lock, vulnerabilities_lock class YAMLReporter(object): def get_report(self): @@ -20,25 +19,31 @@ class YAMLReporter(object): def get_nodes(self): nodes = list() node_locations = set() + services_lock.acquire() for service in services: node_location = str(service.host) if node_location not in node_locations: nodes.append({"type": "Node/Master", "location": str(service.host)}) node_locations.add(node_location) + services_lock.release() return nodes def get_services(self): + services_lock.acquire() services_data = [{"service": service.get_name(), "location": "{}:{}{}".format(service.host, service.port, service.get_path()), "description": service.explain()} for service in services] + services_lock.release() return services_data def get_vulenrabilities(self): + vulnerabilities_lock.acquire() vulnerabilities_data = [{"location": "{}:{}".format(vuln.host, vuln.port) if vuln.host else "", "category": vuln.category.name, "vulnerability": vuln.get_name(), "description": vuln.explain(), "evidence": str(vuln.evidence)} for vuln in vulnerabilities] + vulnerabilities_lock.release() return vulnerabilities_data