diff --git a/examples/sdk_examples/records/list_records.py b/examples/sdk_examples/records/list_records.py index 9e3ccad0..6d45de3b 100644 --- a/examples/sdk_examples/records/list_records.py +++ b/examples/sdk_examples/records/list_records.py @@ -452,4 +452,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/keepercli-package/src/keepercli/biometric/commands/update_name.py b/keepercli-package/src/keepercli/biometric/commands/update_name.py index c353dd46..db674387 100644 --- a/keepercli-package/src/keepercli/biometric/commands/update_name.py +++ b/keepercli-package/src/keepercli/biometric/commands/update_name.py @@ -23,9 +23,6 @@ def __init__(self): parser = argparse.ArgumentParser(prog='biometric update-name', description='Update friendly name of a biometric passkey') super().__init__(parser) - # def get_parser(self): - # return self.parser - def execute(self, context: KeeperParams, **kwargs): """Execute biometric update-name command""" def _update_name(): @@ -147,4 +144,4 @@ def _report_update_results(self, result, credential, new_name): print(f"Old Name: {credential['name']}") print(f"New Name: {new_name}") print(f"Message: {result['message']}") - print("=" * 30) \ No newline at end of file + print("=" * 30) diff --git a/keepercli-package/src/keepercli/commands/enterprise_user.py b/keepercli-package/src/keepercli/commands/enterprise_user.py index 67f08cef..b6a2044b 100644 --- a/keepercli-package/src/keepercli/commands/enterprise_user.py +++ b/keepercli-package/src/keepercli/commands/enterprise_user.py @@ -1054,22 +1054,21 @@ def _get_ecc_data_keys(self, context: KeeperParams, user_ids: Set[int]) -> Dict[ data_key_rq.enterpriseUserId.extend(user_ids) data_key_rs = context.auth.execute_auth_rest( GET_ENTERPRISE_USER_DATA_KEY_ENDPOINT, data_key_rq, - response_type=enterprise_pb2.EnterpriseUserDataKeys) + response_type=APIRequest_pb2.EnterpriseUserIdDataKeyPair) - for key in data_key_rs.keys: - enc_data_key = key.userEncryptedDataKey - if 
enc_data_key: - try: - ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - curve, enc_data_key[:ECC_PUBLIC_KEY_LENGTH]) - shared_key = ecc_private_key.exchange(ec.ECDH(), ephemeral_public_key) - digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) - digest.update(shared_key) - enc_key = digest.finalize() - data_key = utils.crypto.decrypt_aes_v2(enc_data_key[ECC_PUBLIC_KEY_LENGTH:], enc_key) - data_keys[key.enterpriseUserId] = data_key - except Exception as e: - logger.debug(e) + enc_data_key = data_key_rs.encryptedDataKey + if enc_data_key: + try: + ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point( + curve, enc_data_key[:ECC_PUBLIC_KEY_LENGTH]) + shared_key = ecc_private_key.exchange(ec.ECDH(), ephemeral_public_key) + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(shared_key) + enc_key = digest.finalize() + data_key = utils.crypto.decrypt_aes_v2(enc_data_key[ECC_PUBLIC_KEY_LENGTH:], enc_key) + data_keys[data_key_rs.enterpriseUserId] = data_key + except Exception as e: + logger.debug(e) return data_keys diff --git a/keepercli-package/src/keepercli/commands/pam/debug/__init__.py b/keepercli-package/src/keepercli/commands/pam/debug/__init__.py new file mode 100644 index 00000000..d6f6ac91 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/__init__.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from keepersdk.helpers.keeper_dag.dag_utils import value_to_boolean +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ....params import KeeperParams + from keepersdk.helpers.keeper_dag.connection import ConnectionBase + + +def get_connection(context: KeeperParams) -> ConnectionBase: + if value_to_boolean(os.environ.get("USE_LOCAL_DAG", False)) is False: + from keepersdk.helpers.keeper_dag.connection.commander import Connection as CommanderConnection + return CommanderConnection(context=context) + else: + from 
keepersdk.helpers.keeper_dag.connection.local import Connection as LocalConnection + return LocalConnection() \ No newline at end of file diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_acl.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_acl.py new file mode 100644 index 00000000..b8b8b219 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_acl.py @@ -0,0 +1,156 @@ + +import logging +import re +import argparse +import time + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.dag_types import UserAcl +from keepersdk.helpers.keeper_dag.dag import EdgeType +from keepersdk.helpers.keeper_dag.dag_types import DiscoveryObject +from keepersdk.helpers.keeper_dag.infrastructure import Infrastructure +from keepersdk.helpers.keeper_dag.user_service import UserService +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... 
import api + +logger = api.get_logger() + +class PAMDebugACLCommand(PAMGatewayActionDiscoverCommandBase): + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug acl') + PAMDebugACLCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + parser.add_argument('--configuration-uid', "-c", required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='User UID.') + parser.add_argument('--parent-uid', '-r', required=True, dest='parent_uid', action='store', + help='Resource or Configuration UID.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. 
Default is 0', type=int, default=0) + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + user_uid = kwargs.get("user_uid") + parent_uid = kwargs.get("parent_uid") + debug_level = int(kwargs.get("debug_level", 0)) + + logger.info("") + + configuration_uid = kwargs.get('configuration_uid') + + gateway_context = GatewayContext.from_gateway(context=context, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + record_link = RecordLink(record=gateway_context.configuration, + context=context, + logger=logger, + debug_level=debug_level) + + user_record = context.vault.vault_data.load_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + logger.info(f"The user record is {user_record.title}") + + if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User record.") + return + + parent_record = context.vault.vault_data.load_record(parent_uid) + if parent_record is None: + logger.error(f"The parent record does not exists.") + return + + logger.info(f"The parent record is {parent_record.title}") + + if parent_record.record_type.startswith("pam") is False: + logger.error(f"The parent record is not a PAM record.") + return + + if parent_record.record_type == PAM_USER: + logger.error(f"The parent record cannot be a PAM User record.") + return + + parent_is_config = parent_record.record_type.endswith("Configuration") + + # Get the ACL between the user and the parent. + # It might not exist. + acl_exists = True + acl = record_link.get_acl(user_uid, parent_uid) + if acl is None: + logger.info("No existing ACL, creating an ACL.") + acl = UserAcl() + acl_exists = False + + # Make sure the ACL for cloud user is set. 
+ if parent_is_config is True: + logger.info("Is an IAM user.") + acl.is_iam_user = True + + rl_parent_vertex = record_link.dag.get_vertex(parent_uid) + if rl_parent_vertex is None: + logger.info("Parent record linking vertex did not exists, creating one.") + rl_parent_vertex = record_link.dag.add_vertex(parent_uid) + + rl_user_vertex = record_link.dag.get_vertex(user_uid) + if rl_user_vertex is None: + logger.info("User record linking vertex did not exists, creating one.") + rl_user_vertex = record_link.dag.add_vertex(user_uid) + + has_admin_uid = record_link.get_admin_record_uid(parent_uid) + if has_admin_uid is not None: + logger.info("Parent record already has an admin.") + else: + logger.info("Parent record does not have an admin.") + + belongs_to_vertex = record_link.acl_has_belong_to_record_uid(user_uid) + if belongs_to_vertex is None: + logger.info("User record does not belong to any resource, or provider.") + else: + if not belongs_to_vertex.active: + logger.info("User record belongs to an inactive parent.") + else: + logger.info("User record belongs to another record.") + + logger.info("") + + while True: + res = input(f"Does this user belong to {parent_record.title} Y/N >").lower() + if res == "y": + acl.belongs_to = True + break + elif res == "n": + acl.belongs_to = False + break + + if has_admin_uid is None: + while True: + res = input(f"Is this user the admin of {parent_record.title} Y/N >").lower() + if res == "y": + acl.is_admin = True + break + elif res == "n": + acl.is_admin = False + break + + try: + record_link.belongs_to(user_uid, parent_uid, acl=acl) + record_link.save() + logger.info(f"Updated/added ACL between {user_record.title} and " + f"{parent_record.title}") + except Exception as err: + logger.error(f"Could not update ACL: {err}") diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_gateway.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_gateway.py new file mode 100644 index 00000000..4e958fa0 --- 
/dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_gateway.py @@ -0,0 +1,99 @@ +import argparse + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.infrastructure import Infrastructure +from keepersdk.helpers.keeper_dag.user_service import UserService +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... import api +from .debug_graph import PAMDebugGraphCommand + +logger = api.get_logger() + +class PAMDebugGatewayCommand(PAMGatewayActionDiscoverCommandBase): + + type_name_map = { + PAM_USER: "PAM User", + PAM_MACHINE: "PAM Machine", + PAM_DATABASE: "PAM Database", + PAM_DIRECTORY: "PAM Directory", + } + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug gateway') + PAMDebugGatewayCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--configuration-uid', "-c", required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + debug_level = kwargs.get("debug_level", False) + + configuration_uid = kwargs.get('configuration_uid') + vault = context.vault + + gateway_context = GatewayContext.from_gateway(context=context, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + infra = Infrastructure(record=gateway_context.configuration, vault=vault, fail_on_corrupt=False) + infra.load() + + record_link 
= RecordLink(record=gateway_context.configuration, vault=vault, fail_on_corrupt=False) + user_service = UserService(record=gateway_context.configuration, vault=vault, fail_on_corrupt=False) + + if gateway_context is None: + logger.error(f"Cannot get gateway information. Gateway may not be up.") + return + + logger.info("") + logger.info(f"Gateway Information") + logger.info(f" Gateway UID: {gateway_context.gateway_uid}") + logger.info(f" Gateway Name: {gateway_context.gateway_name}") + if gateway_context.configuration is not None: + logger.info(f" Configuration UID: {gateway_context.configuration_uid}") + logger.info(f" Configuration Title: {gateway_context.configuration.title}") + logger.info(f" Configuration Key Bytes Hex: {gateway_context.configuration.record_key.hex()}") + else: + logger.error(f"The gateway appears to not have a configuration.") + logger.info("") + + graph = PAMDebugGraphCommand() + + if infra.dag.has_graph is True: + logger.info(f"Infrastructure Graph") + graph.do_list(context=context, gateway_context=gateway_context, graph_type="infra", debug_level=debug_level, + indent=1) + else: + logger.error(f"The gateway configuration does not have a infrastructure graph.") + + logger.info("") + + if record_link.dag.has_graph is True: + logger.info(f"Record Linking Graph") + graph.do_list(context=context, gateway_context=gateway_context, graph_type="rl", debug_level=debug_level, + indent=1) + else: + logger.error(f"The gateway configuration does not have a record linking graph.") + + logger.info("") + + if user_service.dag.has_graph is True: + logger.info(f"User to Service/Task Graph") + graph.do_list(context=context, gateway_context=gateway_context, graph_type="service", debug_level=debug_level, + indent=1) + else: + logger.error(f"The gateway configuration does not have a user to service/task graph.") + + logger.info("") diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_graph.py 
b/keepercli-package/src/keepercli/commands/pam/debug/debug_graph.py new file mode 100644 index 00000000..ed9ee5a0 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_graph.py @@ -0,0 +1,676 @@ +import argparse +import logging +from typing import Optional + +from keepersdk.helpers.keeper_dag.jobs import Jobs +from keepersdk.helpers.keeper_dag.process import VERTICES_SORT_MAP, DiscoveryObject + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.infrastructure import Infrastructure +from keepersdk.helpers.keeper_dag.user_service import UserService +from keepersdk.helpers.keeper_dag.dag_types import DiscoveryUser, DiscoveryDirectory, DiscoveryMachine, DiscoveryDatabase, JobContent +from keepersdk.helpers.keeper_dag.dag import DAGVertex, DAG +from keepersdk.helpers.keeper_dag.constants import DIS_INFRA_GRAPH_ID, RECORD_LINK_GRAPH_ID, USER_SERVICE_GRAPH_ID, DIS_JOBS_GRAPH_ID +from keepersdk.helpers.keeper_dag.dag_sort import sort_infra_vertices +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... import api +from . 
import get_connection + +logger = api.get_logger() + +class PAMDebugGraphCommand(PAMGatewayActionDiscoverCommandBase): + + NO_RECORD = "NO RECORD" + OTHER = "OTHER" + + mapping = { + PAM_USER: {"order": 1, "sort": "_sort_name", "item": DiscoveryUser, "key": "user"}, + PAM_DIRECTORY: {"order": 1, "sort": "_sort_name", "item": DiscoveryDirectory, "key": "host_port"}, + PAM_MACHINE: {"order": 2, "sort": "_sort_host", "item": DiscoveryMachine, "key": "host"}, + PAM_DATABASE: {"order": 3, "sort": "_sort_host", "item": DiscoveryDatabase, "key": "host_port"}, + } + + graph_id_map = { + "infra": DIS_INFRA_GRAPH_ID, + "rl": RECORD_LINK_GRAPH_ID, + "service": USER_SERVICE_GRAPH_ID, + "jobs": DIS_JOBS_GRAPH_ID + } + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug graph') + PAMDebugGraphCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + parser.add_argument('--configuration-uid', "-c", required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--type', '-t', required=True, choices=['infra', 'rl', 'service', 'jobs'], + dest='graph_type', action='store', help='Graph type', default='infra') + parser.add_argument('--raw', required=False, dest='raw', action='store_true', + help='Render raw graph. 
Will render corrupt graphs.') + parser.add_argument('--list', required=False, dest='do_text_list', action='store_true', + help='List items in a list.') + parser.add_argument('--render', required=False, dest='do_render', action='store_true', + help='Render a graph') + parser.add_argument('--file', '-f', required=False, dest='filepath', action='store', + default="keeper_graph", help='Base name for the graph file.') + parser.add_argument('--format', required=False, choices=['raw', 'dot', 'twopi', 'patchwork'], + dest='format', default="dot", action='store', help='The format of the graph.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. Default is 0', type=int, default=0) + + def _do_text_list_infra(self, context: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + infra = Infrastructure(record=gateway_context.configuration, context=context, logger=logging, + debug_level=debug_level) + infra.load(sync_point=0) + + try: + configuration = infra.get_root.has_vertices()[0] + except (Exception,): + logger.error(f"Could not find the configuration in the infrastructure graph. " + f"Has discovery been run for this gateway?") + + return + + line_start = { + 0: "", + 1: "* ", + 2: "- ", + } + + def _handle(current_vertex: DAGVertex, indent: int = 0, last_record_type: Optional[str] = None): + + if not current_vertex.active: + return + + pad = "" + if indent > 0: + pad = "".ljust(4 * indent, ' ') + + text = "" + ls = line_start.get(indent, " ") + + if not current_vertex.active: + text += f"{pad}{current_vertex.uid} (Inactive)" + elif not current_vertex.corrupt: + current_content = DiscoveryObject.get_discovery_object(current_vertex) + if current_content.record_uid is None: + text += f"{pad}{ls}{current_vertex.uid}; {current_content.title} does not have a record." 
+ else: + record = context.vault.vault_data.load_record(current_content.record_uid) + if record is not None: + text += f"{pad}{ls}" + (f"{current_vertex.uid}; {record.title}; {record.record_uid}") + else: + text += f"{pad}{ls}" + (f"{current_vertex.uid}; {current_content.title}; have record uid, record does not exists, might have to sync.") + else: + text += f"{pad}{current_vertex.uid} (Corrupt)" + + logger.info(text) + + record_type_to_vertices_map = sort_infra_vertices(current_vertex) + # Process the record type by their map order in ascending order. + + # Sort the record types by their order in the constant. + # 'order' is an int. + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + for vertex in record_type_to_vertices_map[record_type]: + if last_record_type is None or last_record_type != record_type: + if indent == 0: + logger.info(f"{pad} {record_type}") + last_record_type = record_type + + _handle(vertex, indent=indent+1) + + logger.info("") + _handle(configuration, indent=indent) + logger.info("") + + def _do_text_list_rl(self, context: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + logger.info("") + + pad = "" + if indent > 0: + pad = "".ljust(4 * indent, ' ') + + record_link = RecordLink(record=gateway_context.configuration, + context=context, + logger=logging, + debug_level=debug_level) + configuration = record_link.dag.get_root + + record = context.vault.vault_data.load_record(record_uid=configuration.uid) + if record is None: + logger.error(f"Configuration record does not exists.") + return + + logger.info(f"{pad}{record.record_type}, {record.title}, {record.record_uid}") + + if configuration.has_data: + try: + data = configuration.content_as_dict + logger.info(f"{pad} . data") + for k, v in data.items(): + logger.info(f"{pad} + {k} = {v}") + except Exception as err: + logger.error(f"{pad} ! 
data not JSON: {err}") + + def _group(configuration_vertex: DAGVertex) -> dict: + + group = { + PAM_USER: [], + PAM_DIRECTORY: [], + PAM_DATABASE: [], + PAM_MACHINE: [], + PAMDebugGraphCommand.NO_RECORD: [], + PAMDebugGraphCommand.OTHER: [] + } + + for vertex in configuration_vertex.has_vertices(): + record = context.vault.vault_data.load_record(record_uid=vertex.uid) + if record is None: + group[PAMDebugGraphCommand.NO_RECORD].append({ + "v": vertex + }) + continue + rt = record.record_type + if rt not in group: + rt = PAMDebugGraphCommand.OTHER + group[rt].append({ + "v": vertex, + "r": record + }) + + return group + + group = _group(configuration) + + for record_type in [PAM_USER, PAM_DIRECTORY, PAM_MACHINE, PAM_DATABASE]: + if len(group[record_type]) > 0: + logger.info(f"{pad} {record_type}") + for item in group[record_type]: + vertex = item.get("v") + record = item.get("r") + text = f"{record.title}; {record.record_uid}" + if not vertex.active: + text += " Inactive" + logger.info(f"{pad} * {text}") + + # These are cloud users + if record_type == PAM_USER: + acl = record_link.get_acl(vertex.uid, configuration.uid) + if acl is None: + logger.info(f"{pad} missing ACL") + else: + if acl.is_iam_user: + logger.info(f"{pad} . is IAM user") + if acl.is_admin: + logger.info(f"{pad} . is the Admin") + if acl.belongs_to: + logger.info(f"{pad} . belongs to this resource") + else: + logger.info(f"{pad} . looks like directory user") + + if acl.rotation_settings: + if acl.rotation_settings.noop: + logger.info(f"{pad} . is a NOOP") + if acl.rotation_settings.disabled: + logger.info(f"{pad} . rotation is disabled") + + if (acl.rotation_settings.saas_record_uid_list is not None + and len(acl.rotation_settings.saas_record_uid_list) > 0): + logger.info(f"{pad} . has SaaS rotation: " + f"{acl.rotation_settings.saas_record_uid_list[0]}") + + continue + + if vertex.has_data: + try: + data = vertex.content_as_dict + logger.info(f"{pad} . 
data") + for k, v in data.items(): + logger.info(f"{pad} + {k} = {v}") + except Exception as err: + logger.error(f"{pad} ! data not JSON: {err}") + + children = vertex.has_vertices() + if len(children) > 0: + bad = [] + for child in children: + child_record = context.vault.vault_data.load_record(record_uid=child.uid) + if child_record is None: + if child.active: + bad.append(f"- Record UID {child.uid} does not exists.") + continue + else: + logger.info(f"{pad} - {child_record.title}; {child_record.record_uid}") + acl = record_link.get_acl(child.uid, vertex.uid) + if acl is None: + logger.info(f"{pad} missing ACL") + else: + if acl.is_admin: + logger.info(f"{pad} . is the Admin") + if acl.belongs_to: + logger.info(f"{pad} . belongs to this resource") + else: + logger.info(f"{pad} . looks like directory user") + + if child.has_data: + try: + data = child.content_as_dict + logger.info(f"{pad} . data") + for k, v in data.items(): + logger.info(f"{pad} + {k} = {v}") + except Exception as err: + logger.info(f"{pad} ! 
data not JSON: {err}") + for i in bad: + logger.error(f"{pad} {i}") + + if len(group[PAMDebugGraphCommand.OTHER]) > 0: + logger.info(f"{pad} Other PAM Types") + for item in group[PAMDebugGraphCommand.OTHER]: + vertex = item.get("v") + record = item.get("r") + text = f"{record.record_type}; {record.title}; {record.record_uid}" + if not vertex.active: + text += " Inactive" + logger.info(f"{pad} * {text}") + + if len(group[PAMDebugGraphCommand.NO_RECORD]) > 0: + + # TODO: Check the infra graph for information + logger.info(f"{pad} In Graph, No Vault Record") + for item in group[PAMDebugGraphCommand.NO_RECORD]: + vertex = item.get("v") + logger.info(f"{pad} * {vertex.uid}") + + def _do_text_list_service(self, context: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + user_service = UserService(record=gateway_context.configuration, context=context, logger=logging, + debug_level=debug_level) + configuration = user_service.dag.get_root + + def _handle(current_vertex: DAGVertex, parent_vertex: Optional[DAGVertex] = None, indent: int = 0): + + pad = "" + if indent > 0: + pad = "".ljust(2 * indent, ' ') + "* " + + record = context.vault.vault_data.load_record(record_uid=current_vertex.uid) + if record is None: + if not current_vertex.active: + logger.info(f"{pad}Record {current_vertex.uid} does not exists, inactive in the graph.") + else: + logger.info(f"{pad}Record {current_vertex.uid} does not exists, active in the graph.") + return + elif not current_vertex.active: + logger.info(f"{pad}{record.record_type}, {record.title}, {record.record_uid} exists, " + "inactive in the graph.") + return + + acl_text = "" + if parent_vertex is not None: + acl = user_service.get_acl(resource_uid=parent_vertex.uid, user_uid=current_vertex.uid) + if acl is not None: + acl_text = "No Services" + acl_parts = [] + if acl.is_service: + acl_parts.append("Service") + if acl.is_task: + acl_parts.append("Task") + if acl.is_iis_pool: + 
acl_parts.append("Task") + if len(acl_parts) > 0: + acl_text = ", ".join(acl_parts) + acl_text = f" -> {acl_text}" + + logger.info(f"{pad}{record.record_type}, {record.title}, {record.record_uid}{acl_text}") + + for vertex in current_vertex.has_vertices(): + _handle(current_vertex=vertex, parent_vertex=current_vertex, indent=indent+1) + + _handle(current_vertex=configuration, parent_vertex=None, indent=indent) + + def _do_text_list_jobs(self, context: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + infra = Infrastructure(record=gateway_context.configuration, context=context, logger=logging, + debug_level=debug_level, fail_on_corrupt=False) + infra.load(sync_point=0) + + pad = "" + if indent > 0: + pad = "".ljust(2 * indent, ' ') + "* " + + conn = get_connection(context=context) + graph_sync = DAG(conn=conn, record=gateway_context.configuration, logger=logging, debug_level=debug_level, + graph_id=DIS_JOBS_GRAPH_ID) + graph_sync.load(0) + configuration = graph_sync.get_root + vertices = configuration.has_vertices() + if len(vertices) == 0: + logger.error(f"The jobs graph has not been initialized. 
Only has root vertex.") + return + + vertex = vertices[0] + if not vertex.has_data: + logger.error(f"The job vertex does not contain any data") + return + + current_json = vertex.content_as_str + if current_json is None: + logger.error(f"The current job vertex content is None") + return + + content = JobContent.model_validate_json(current_json) + logger.info(f"{pad}Active Job ID: {content.active_job_id}") + logger.info("") + logger.info(f"{pad}History") + logger.info("") + for job in content.job_history: + logger.info(f"{pad} --------------------------------------") + logger.info(f"{pad} Job Id: {job.job_id}") + logger.info(f"{pad} Started: {job.start_ts_str}") + logger.info(f"{pad} Ended: {job.end_ts_str}") + logger.info(f"{pad} Duration: {job.duration_sec_str}") + logger.info(f"{pad} Infra Sync Point: {job.sync_point}") + if job.success: + logger.info(f"{pad} Status: Success") + else: + logger.info(f"{pad} Status: Fail") + if job.error is not None: + logger.info(f"{pad} Error: {job.error}") + + logger.info("") + + if job.delta is None: + logger.error(f"{pad}The job is missing a delta, never finished discovery.") + else: + if len(job.delta.added) > 0: + logger.info(f"{pad} Added") + for added in job.delta.added: + vertex = infra.dag.get_vertex(added.uid) + if vertex is None: + logger.info(f"{pad} * Vertex {added.uid} does not exists.") + else: + if not vertex.active: + logger.info(f"{pad} * Vertex {added.uid} is inactive.") + elif vertex.corrupt: + logger.info(f"{pad} * Vertex {added.uid} is corrupt.") + else: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f"{pad} * {content.description}; Record UID: {content.record_uid}") + logger.info("") + + if len(job.delta.changed) > 0: + logger.info(f"{pad} Changed") + for changed in job.delta.changed: + vertex = infra.dag.get_vertex(changed.uid) + if vertex is None: + logger.info(f"{pad} * Vertex {changed.uid} does not exists.") + else: + if not vertex.active: + logger.info(f"{pad} * Vertex 
{changed.uid} is inactive.") + elif vertex.corrupt: + logger.info(f"{pad} * Vertex {changed.uid} is corrupt.") + else: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f"{pad} * {content.description}; Record UID: {content.record_uid}") + if changed.changes is not None: + for k, v in changed.changes.items(): + logger.info(f"{pad} {k} = {v}") + logger.info("") + + if len(job.delta.deleted) > 0: + logger.info(f"{pad} Deleted") + for deleted in job.delta.deleted: + logger.info(f"{pad} * Removed vertex {deleted.uid}.") + logger.info("") + + def _do_render_infra(self, context: KeeperParams, gateway_context: GatewayContext, filepath: str, graph_format: str, + debug_level: int = 0): + + infra = Infrastructure(record=gateway_context.configuration, context=context, logger=logging, + debug_level=debug_level) + infra.load(sync_point=0) + + logger.info("") + dot_instance = infra.to_dot( + graph_type=graph_format if graph_format != "raw" else "dot", + show_only_active_vertices=False, + show_only_active_edges=False + ) + if graph_format == "raw": + logger.info(dot_instance) + else: + try: + dot_instance.render(filepath) + logger.info(f"Infrastructure graph rendered to {filepath}") + except Exception as err: + logger.error(f"Could not generate graph: {err}") + raise err + logger.info("") + + def _do_render_rl(self, context: KeeperParams, gateway_context: GatewayContext, filepath: str, graph_format: str, + debug_level: int = 0): + + rl = RecordLink(record=gateway_context.configuration, + context=context, + logger=logging, + debug_level=debug_level) + + logger.info("") + dot_instance = rl.to_dot( + graph_type=graph_format if graph_format != "raw" else "dot", + show_only_active_vertices=False, + show_only_active_edges=False + ) + if graph_format == "raw": + logger.info(dot_instance) + else: + try: + dot_instance.render(filepath) + logger.info(f"Record linking graph rendered to {filepath}") + except Exception as err: + logger.error(f"Could not generate graph: 
{err}") + raise err + logger.info("") + + def _do_render_service(self, context: KeeperParams, gateway_context: GatewayContext, filepath: str, + graph_format: str, debug_level: int = 0): + + service = UserService(record=gateway_context.configuration, context=context, logger=logging, + debug_level=debug_level) + + logger.info("") + dot_instance = service.to_dot( + graph_type=graph_format if graph_format != "raw" else "dot", + show_only_active_vertices=False, + show_only_active_edges=False + ) + if graph_format == "raw": + logger.info(dot_instance) + else: + try: + dot_instance.render(filepath) + logger.info(f"User service/tasks graph rendered to {filepath}") + except Exception as err: + logger.error(f"Could not generate graph: {err}") + raise err + logger.info("") + + def _do_render_jobs(self, context: KeeperParams, gateway_context: GatewayContext, filepath: str, + graph_format: str, debug_level: int = 0): + + jobs = Jobs(record=gateway_context.configuration, context=context, logger=logging, debug_level=debug_level) + + logger.info("") + dot_instance = jobs.dag.to_dot() + if graph_format == "raw": + logger.info(dot_instance) + else: + try: + dot_instance.render(filepath) + logger.info(f"Job graph rendered to {filepath}") + except Exception as err: + logger.error(f"Could not generate graph: {err}") + raise err + logger.info("") + + def _do_raw_text_list(self, context: KeeperParams, gateway_context: GatewayContext, graph_id: int = 0, + debug_level: int = 0): + + logging.debug(f"loading graph id {graph_id}, for record uid {gateway_context.configuration.record_uid}") + + conn = get_connection(context=context) + dag = DAG(conn=conn, record=gateway_context.configuration, graph_id=graph_id, fail_on_corrupt=False, + logger=logging, debug_level=debug_level) + dag.load(sync_point=0) + logger.info("") + if dag.is_corrupt is True: + logger.error(f"The graph is corrupt at Vertex UIDs: {', '.join(dag.corrupt_uids)}") + logger.error("") + + logger.debug("DAG DOT 
-------------------------------") + logger.debug(str(dag.to_dot())) + logger.debug("DAG DOT -------------------------------") + + line_start = { + 0: "", + 1: "* ", + 2: "- ", + 3: ". ", + } + + def _handle(current_vertex: DAGVertex, last_vertex: Optional[DAGVertex] = None, indent: int = 0): + + pad = "" + if indent > 0: + pad = "".ljust(4 * indent, ' ') + + ls = line_start.get(indent, " ") + text = f"{pad}{ls}{current_vertex.uid}" + + edge_types = [] + if last_vertex is not None: + for edge in current_vertex.edges: # type: DAGEdge + if not edge.active: + continue + if edge.head_uid == last_vertex.uid: + edge_types.append(edge.edge_type.value) + if len(edge_types) > 0: + text += f"; edges: {', '.join(edge_types)}" + + if not current_vertex.active: + text += " Inactive" + if current_vertex.corrupt: + text += " Corrupt" + + logger.info(text) + + if not current_vertex.active: + logger.debug(f"vertex {current_vertex.uid} is not active, will not get children.") + return + + vertices = current_vertex.has_vertices() + if len(vertices) == 0: + logger.debug(f"vertex {current_vertex.uid} does not have any children.") + return + + for vertex in vertices: + _handle(vertex, current_vertex, indent=indent + 1) + + logger.info("") + _handle(dag.get_root) + logger.info("") + + def _do_raw_render_graph(self, context: KeeperParams, gateway_context: GatewayContext, filepath: str, + graph_format: str, graph_id: int = 0, debug_level: int = 0): + + conn = get_connection(context=context) + dag = DAG(conn=conn, record=gateway_context.configuration, graph_id=graph_id, fail_on_corrupt=False, + logger=logging, debug_level=debug_level) + dag.load(sync_point=0) + dot = dag.to_dot(graph_format=graph_format) + if graph_format == "raw": + logger.info(dot) + else: + try: + dot.render(filepath) + logger.info(f"Graph rendered to {filepath}") + except Exception as err: + logger.error(f"Could not generate graph: {err}") + raise err + + logger.info("") + + def do_list(self, context: KeeperParams, 
gateway_context: GatewayContext, graph_type: str, debug_level: int = 0, + indent: int = 0): + list_func = getattr(self, f"_do_text_list_{graph_type}") + list_func(context=context, + gateway_context=gateway_context, + debug_level=debug_level, + indent=indent) + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + raw = kwargs.get("raw", False) + graph_type = kwargs.get("graph_type") + do_text_list = kwargs.get("do_text_list") + do_render = kwargs.get("do_render") + debug_level = int(kwargs.get("debug_level", 0)) + + configuration_uid = kwargs.get('configuration_uid') + + vault = context.vault + + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + if raw: + if do_text_list: + self._do_raw_text_list(context=context, + gateway_context=gateway_context, + graph_id=PAMDebugGraphCommand.graph_id_map.get(graph_type), + debug_level=debug_level) + if do_render: + filepath = kwargs.get("filepath") + graph_format = kwargs.get("format") + self._do_raw_render_graph(context=context, + gateway_context=gateway_context, + filepath=filepath, + graph_format=graph_format, + graph_id=PAMDebugGraphCommand.graph_id_map.get(graph_type), + debug_level=debug_level) + else: + if do_text_list: + self.do_list( + context=context, + gateway_context=gateway_context, + graph_type=graph_type, + debug_level=debug_level + ) + if do_render: + filepath = kwargs.get("filepath") + graph_format = kwargs.get("format") + render_func = getattr(self, f"_do_render_{graph_type}") + render_func(context=context, + gateway_context=gateway_context, + filepath=filepath, + graph_format=graph_format, + debug_level=debug_level) diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_info.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_info.py new file mode 100644 index 
00000000..a386958b --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_info.py @@ -0,0 +1,552 @@ +import re +import argparse +import time + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.dag_types import UserAcl +from keepersdk.helpers.keeper_dag.dag import EdgeType +from keepersdk.helpers.keeper_dag.dag_types import DiscoveryObject +from keepersdk.helpers.keeper_dag.infrastructure import Infrastructure +from keepersdk.helpers.keeper_dag.user_service import UserService +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... import api + +logger = api.get_logger() + + +class PAMDebugInfoCommand(PAMGatewayActionDiscoverCommandBase): + + type_name_map = { + PAM_USER: "PAM User", + PAM_MACHINE: "PAM Machine", + PAM_DATABASE: "PAM Database", + PAM_DIRECTORY: "PAM Directory", + } + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug info') + PAMDebugInfoCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--record-uid', '-i', required=True, dest='record_uid', action='store', + help='Keeper PAM record UID.') + + def execute(self, context: KeeperParams, **kwargs): + + record_uid = kwargs.get("record_uid") + vault = context.vault + record = vault.vault_data.load_record(record_uid) + if record is None: + logger.error(f"Record does not exists.") + return + + if record.record_type not in ["pamUser", "pamMachine", "pamDatabase", "pamDirectory"]: + if re.search(r'^pam.+Configuration$', record.record_type) is None: + logger.error(f"The record is a {record.record_type}. 
This is not a PAM record.") + return + + resource_uid = None + controller_uid = None + + record_rotation = params.record_rotation_cache.get(record_uid) + + # Rotation setting don't exist, check each configuration for an active record. + if record_rotation is None: + logger.warning(f"PAM record does not have protobuf rotation settings, " + f"checking all configurations.") + + # Get all the PAM configuration records in the Vault; configurations are version 6 + configuration_records = GatewayContext.get_configuration_records(vault=vault) + if len(configuration_records) == 0: + logger.error(f"Cannot find any PAM configuration records in the Vault") + + for configuration_record in configuration_records: + + record_link = RecordLink(record=configuration_record, vault=vault) + record_vertex = record_link.dag.get_vertex(record.record_uid) + if record_vertex is not None and record_vertex.active is True: + controller_uid = configuration_record.record_uid + break + if controller_uid is None: + logger.error(f"Could not find the record in any record linking graph; " + f"checked all configuration records.") + return + + # Else just get information from the rotation settings + else: + + controller_uid = record_rotation.get("configuration_uid") + if controller_uid is None: + logger.error(f"Record does not have the PAM Configuration set.") + return + + resource_uid = record_rotation.get("resource_uid") + + configuration_record = vault.vault_data.load_record(controller_uid) + if configuration_record is None: + logger.error(f"The configuration record {controller_uid} does not exist.") + return + + gateway_context = GatewayContext.from_configuration_uid(vault=vault, configuration_uid=controller_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway for configuration record.{controller_uid}") + return + + infra = Infrastructure(record=configuration_record, vault=vault) + infra.load() + record_link = RecordLink(record=configuration_record, vault=vault) + 
user_service = UserService(record=configuration_record, vault=vault) + + logger.info("") + logger.info(f"Record Information") + logger.info(f" {('Record UID')}: {record_uid}") + logger.info(f" {('Record Title')}: {record.title}") + logger.info(f" {('Record Type')}: {record.record_type}") + logger.info(f" {('Configuration UID')}: {configuration_record.record_uid}") + logger.info(f" {('Configuration Key Bytes Hex')}: {configuration_record.record_key.hex()}") + if resource_uid is not None: + logger.info(f" {('Resource UID')}: {resource_uid}") + + if gateway_context is not None: + logger.info(f" {('Gateway Name')}: {gateway_context.gateway_name}") + logger.info(f" {('Gateway UID')}: {gateway_context.gateway_uid}") + else: + logger.error(f" {('Cannot get gateway information. Gateway may not be up.')}") + logger.info("") + + def _print_field(f): + if f.type == "password": + display_value = f"Password is set" + if f.value == 0 or len(f.value) == 0: + display_value = f"Password IS NOT set" + logger.info(f" * Type: {f.type}, Label: {f.label or 'NO LABEL'}, " + f"Value(s): {display_value}") + elif f.label == "privatePEMKey": + display_value = f"Private Key is set" + if field.value == 0 or len(f.value) == 0: + display_value = f"Private Key IS NOT set" + logger.info(f" * Type: {f.type}, Label: {f.label or 'NO LABEL'}, " + f"Value(s): {display_value}") + elif f.type == "secret": + display_value = f"Secret value is set" + if field.value == 0 or len(f.value) == 0: + display_value = f"Secret value IS NOT set" + logger.info(f" * Type: {f.type}, Label: {f.label or 'NO LABEL'}, " + f"Value(s): {display_value}") + else: + logger.info(f" * Type: {f.type}, Label: {f.label or 'NO LABEL'}, " + f"Value(s): {f.value}") + + logger.info(f"Fields") + logger.info(f" Record Type Fields") + if record.fields is not None and len(record.fields) > 0: + for field in record.fields: + _print_field(field) + else: + logger.error(f" Record does not have record type fields!") + logger.info("") + 
logger.info(f" Custom Fields") + if record.custom is not None and len(record.custom) > 0: + for field in record.custom: + _print_field(field) + else: + logger.error(f" Record does not have custom fields.") + logger.info("") + + discovery_vertices = infra.dag.search_content({"record_uid": record.record_uid}) + record_vertex = record_link.dag.get_vertex(record.record_uid) + + if record_vertex is not None: + logger.info(f"Record Linking") + record_parent_vertices = record_vertex.belongs_to_vertices() + logger.info(f" Parent Records") + if len(record_parent_vertices) > 0: + for record_parent_vertex in record_parent_vertices: + + parent_record = vault.vault_data.load_record( + record_parent_vertex.uid) + if parent_record is None: + logger.error(f" * Parent record {record_parent_vertex.uid} " + f"does not exists.") + continue + + acl_edge = record_vertex.get_edge(record_parent_vertex, EdgeType.ACL) + if acl_edge is not None: + acl_content = acl_edge.content_as_object(UserAcl) # type: UserAcl + logger.info(f" * ACL to {parent_record.record_type}; {parent_record.title}; " + f"{record_parent_vertex.uid}") + if acl_content.is_admin: + logger.info(f" . Is Admin") + if acl_content.belongs_to: + logger.info(f" . Belongs") + else: + logger.info(f" . Is Remote user") + + if acl_content.rotation_settings is None: + logger.error(f" . There are no rotation settings!") + else: + if (acl_content.rotation_settings.schedule is None + or acl_content.rotation_settings.schedule == ""): + logger.info(f" . No Schedule") + else: + logger.info(f" . Schedule = {acl_content.rotation_settings.get_schedule()}") + + if (acl_content.rotation_settings.pwd_complexity is None + or acl_content.rotation_settings.pwd_complexity == ""): + logger.info(f" . No Password Complexity") + else: + key_bytes = record.record_key + logger.info(f" . Password Complexity = " + f"{acl_content.rotation_settings.get_pwd_complexity(key_bytes)}") + logger.info(f" . 
Disabled = {acl_content.rotation_settings.disabled}") + logger.info(f" . NOOP = {acl_content.rotation_settings.noop}") + logger.info(f" . SaaS Config Records = {acl_content.rotation_settings.saas_record_uid_list}") + + elif record.record_type == PAM_USER: + logger.error(f" * PAM User has NO acl!!!!!!") + + link_edge = record_vertex.get_edge(record_parent_vertex, EdgeType.LINK) + if link_edge is not None: + logger.info(f" * LINK to {parent_record.record_type}; {parent_record.title}; " + f"{record_parent_vertex.uid}") + else: + # This really should not happen + logger.error(f" Record does not have a parent record.") + logger.info("") + + record_child_vertices = record_vertex.has_vertices() + logger.info(f" Child Records") + if len(record_child_vertices) > 0: + for record_child_vertex in record_child_vertices: + child_record = vault.vault_data.load_record( + record_child_vertex.uid) + + if child_record is None: + logger.error(f" * Child record {record_child_vertex.uid} " + f"does not exists.") + continue + + acl_edge = record_child_vertex.get_edge(record_vertex, EdgeType.ACL) + link_edge = record_child_vertex.get_edge(record_vertex, EdgeType.LINK) + if acl_edge is not None: + acl_content = acl_edge.content_as_object(UserAcl) + logger.info(f" * ACL from {child_record.record_type}; {child_record.title}; " + f"{record_child_vertex.uid}") + if acl_content.is_admin: + logger.info(f" . Is Admin") + if acl_content.belongs_to: + logger.info(f" . Belongs") + else: + logger.info(f" . Is Remote user") + elif link_edge is not None: + logger.info(f" * LINK from {child_record.record_type}; {child_record.title}; " + "{record_child_vertex.uid}") + else: + for edge in record_vertex.edges: + logger.info(f" * {edge.edge_type}?") + + else: + # This is OK + logger.error(f" Record does not have any children.") + logger.info("") + + else: + logger.error(f"Cannot find record in record linking.") + + # Only PAM User and PAM Machine can have services and tasks. 
+ # This is really only Windows machines. + if record.record_type == PAM_USER or record.record_type == PAM_MACHINE: + + # Get the user to service/task vertex. + user_service_vertex = user_service.dag.get_vertex(record_uid) + + if user_service_vertex is not None: + + # If the record is a PAM User + if record.record_type == PAM_USER: + + user_results = { + "is_task": [], + "is_service": [] + } + + # Get a list of all the resources the user is the username/password on service/task. + for us_machine_vertex in user_service.get_resource_vertices(record_uid): + + # Get the resource record + us_machine_record = ( + vault.vault_data.load_record(us_machine_vertex.uid)) + + acl = user_service.get_acl(us_machine_vertex.uid, user_service_vertex.uid) + for attr in ["is_task", "is_service"]: + value = getattr(acl, attr) + if value is True: + + # If the resource record does not exist. + if us_machine_record is None: + + # Default the title to Unknown (in red). + # See if we have an infrastructure vertex with this record UID. + # If we do have it, use the title inside the first vertex's data content. + title = "Unknown" + infra_resource_vertices = infra.dag.search_content( + {"record_uid": us_machine_vertex.uid}) + if len(infra_resource_vertices) > 0: + infra_resource_vertex = infra_resource_vertices[0] + if infra_resource_vertex.has_data is True: + content = DiscoveryObject.get_discovery_object(infra_resource_vertex) + title = content.title + + user_results[attr].append(f" * Record {us_machine_vertex.uid}, " + f"{title} does not exists.") + + # Record exists; just use information from the record. 
+ else: + user_results[attr].append(f" * {us_machine_record.title}, " + f"{us_machine_vertex.uid}") + + logger.info(f"Service on Machines") + if len(user_results["is_service"]) > 0: + for service in user_results["is_service"]: + logger.info(service) + else: + logger.info(" PAM User is not used for any services.") + logger.info("") + + logger.info(f"Scheduled Tasks on Machines") + if len(user_results["is_task"]) > 0: + for task in user_results["is_task"]: + logger.info(task) + else: + logger.info(" PAM User is not used for any scheduled tasks.") + logger.info("") + + # If the record is a PAM Machine + else: + user_results = { + "is_task": [], + "is_service": [] + } + + # Get the users that are used for tasks/services on this machine. + for us_user_vertex in user_service.get_user_vertices(record_uid): + + us_user_record = vault.vault_data.load_record( + us_user_vertex.uid) + acl = user_service.get_acl(user_service_vertex.uid, us_user_vertex.uid) + for attr in ["is_task", "is_service"]: + value = getattr(acl, attr) + if value is True: + + # If the user record does not exist. + if us_user_record is None: + + # Default the title to Unknown (in red). + # See if we have an infrastructure vertex with this record UID. + # If we do have it, use the title inside the first vertex's data content. + title = "Unknown" + infra_resource_vertices = infra.dag.search_content( + {"record_uid": us_user_vertex.uid}) + if len(infra_resource_vertices) > 0: + infra_resource_vertex = infra_resource_vertices[0] + if infra_resource_vertex.has_data is True: + content = DiscoveryObject.get_discovery_object(infra_resource_vertex) + title = content.title + + user_results[attr].append(f" * Record {us_user_vertex.uid}, " + f"{title} does not exists.") + + # Record exists; just use information from the record. 
+ else: + user_results[attr].append(f" * {us_user_record.title}, " + f"{us_user_vertex.uid}") + + logger.info(f"Users that are used for Services") + if len(user_results["is_service"]) > 0: + for service in user_results["is_service"]: + logger.info(service) + else: + logger.info(" Machine does not use any non-builtin users for services.") + logger.info("") + + logger.info(f"Users that are used for Scheduled Tasks") + if len(user_results["is_task"]) > 0: + for task in user_results["is_task"]: + logger.info(task) + else: + logger.info(" Machine does not use any non-builtin users for scheduled tasks.") + logger.info("") + else: + logger.error(f"There are no services or schedule tasks associated with this record.") + logger.info("") + try: + if len(discovery_vertices) == 0: + logger.error(f"Could not find any discovery infrastructure vertices for " + f"{record.record_uid}") + elif len(discovery_vertices) > 0: + + if len(discovery_vertices) > 1: + logger.error(f"Found multiple vertices with the record UID of " + f"{record.record_uid}") + for vertex in discovery_vertices: + logger.info(f" * Infrastructure Vertex UID: {vertex.uid}") + logger.info("") + + discovery_vertex = discovery_vertices[0] + content = DiscoveryObject.get_discovery_object(discovery_vertex) + + missing_since = "NA" + if content.missing_since_ts is not None: + missing_since = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(content.missing_since_ts)) + + logger.info(f"Discovery Object Information") + logger.info(f" Vertex UID: {content.uid}") + logger.info(f" Object ID: {content.id}") + logger.info(f" Record UID: {content.record_uid}") + logger.info(f" Parent Record UID: {content.parent_record_uid}") + logger.info(f" Shared Folder UID: {content.shared_folder_uid}") + logger.info(f" Record Type: {content.record_type}") + logger.info(f" Object Type: {content.object_type_value}") + logger.info(f" Ignore Object: {content.ignore_object}") + logger.info(f" Rule Engine Result: {content.action_rules_result}") + 
logger.info(f" Name: {content.name}") + logger.info(f" Generated Title: {content.title}") + logger.info(f" Generated Description: {content.description}") + logger.info(f" Missing Since: {missing_since}") + logger.info(f" Discovery Notes:") + for note in content.notes: + logger.info(f" * {note}") + if content.error is not None: + logger.error(f" Error: {content.error}") + if content.stacktrace is not None: + logger.error(f" Stack Trace:") + logger.error(f"{content.stacktrace}") + logger.info("") + logger.info(f"Record Type Specifics") + + if record.record_type == PAM_USER: + logger.info(f" User: {content.item.user}") + logger.info(f" DN: {content.item.dn}") + logger.info(f" Database: {content.item.database}") + logger.info(f" Active: {content.item.active}") + logger.info(f" Expired: {content.item.expired}") + logger.info(f" Source: {content.item.source}") + elif record.record_type == PAM_MACHINE: + logger.info(f" Host: {content.item.host}") + logger.info(f" IP: {content.item.ip}") + logger.info(f" Port: {content.item.port}") + logger.info(f" Operating System: {content.item.os}") + logger.info(f" Provider Region: {content.item.provider_region}") + logger.info(f" Provider Group: {content.item.provider_group}") + logger.info(f" Is the Gateway: {content.item.is_gateway}") + logger.info(f" Allows Admin: {content.item.allows_admin}") + logger.info(f" Admin Reason: {content.item.admin_reason}") + logger.info("") + # If facts are not set, inside discover may not have been performed for the machine. 
+ if content.item.facts.id is not None and content.item.facts.name is not None: + logger.info(f" Machine Name: {content.item.facts.name}") + logger.info(f" Machine ID: {content.item.facts.id.machine_id}") + logger.info(f" Product ID: {content.item.facts.id.product_id}") + logger.info(f" Board Serial: {content.item.facts.id.board_serial}") + logger.info(f" Directories:") + if content.item.facts.directories is not None and len(content.item.facts.directories) > 0: + for directory in content.item.facts.directories: + logger.info(f" * Directory Domain: {directory.domain}") + logger.info(f" Software: {directory.software}") + logger.info(f" Login Format: {directory.login_format}") + else: + logger.info(" Machines is not using any directories.") + + logger.info("") + logger.info(f" Services (Non Builtin Users):") + if len(content.item.facts.services) > 0: + for service in content.item.facts.services: + logger.info(f" * {service.name} = {service.user}") + else: + logger.info(" Machines has no services that are using non-builtin users.") + + logger.info(f" Scheduled Tasks (Non Builtin Users)") + if len(content.item.facts.tasks) > 0: + for task in content.item.facts.tasks: + logger.info(f" * {task.name} = {task.user}") + else: + logger.info(" Machines has no schedules tasks that are using non-builtin users.") + + logger.info(f" IIS Pools (Non Builtin Users)") + if len(content.item.facts.iis_pools) > 0: + for iis_pool in content.item.facts.iis_pools: + logger.info(f" * {iis_pool.name} = {iis_pool.user}") + else: + logger.info(" Machines has no IIS Pools that are using non-builtin users.") + else: + logger.error(f" Machine facts are not set. 
Discover inside may not have been " + f"performed.") + elif record.record_type == PAM_DATABASE: + logger.info(f" Host: {content.item.host}") + logger.info(f" IP: {content.item.ip}") + logger.info(f" Port: {content.item.port}") + logger.info(f" Database Type: {content.item.type}") + logger.info(f" Database: {content.item.database}") + logger.info(f" Use SSL: {content.item.use_ssl}") + logger.info(f" Provider Region: {content.item.provider_region}") + logger.info(f" Provider Group: {content.item.provider_group}") + logger.info(f" Allows Admin: {content.item.allows_admin}") + logger.info(f" Admin Reason: {content.item.admin_reason}") + elif record.record_type == PAM_DIRECTORY: + logger.info(f" Host: {content.item.host}") + logger.info(f" IP: {content.item.ip}") + logger.info(f" Port: {content.item.port}") + logger.info(f" Directory Type: {content.item.type}") + logger.info(f" Use SSL: {content.item.use_ssl}") + logger.info(f" Provider Region: {content.item.provider_region}") + logger.info(f" Provider Group: {content.item.provider_group}") + logger.info(f" Allows Admin: {content.item.allows_admin}") + logger.info(f" Admin Reason: {content.item.admin_reason}") + else: + for k, v in content.item: + logger.info(f" {k}: {v}") + + # Configuration records do not belong to other record; don't show. + if record.version != 6: + logger.info("") + logger.info(f"Belongs To Vertices (Parents)") + vertices = discovery_vertex.belongs_to_vertices() + for vertex in vertices: + try: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = discovery_vertex.get_edge(vertex, edge_type=edge_type) + if edge is not None: + logger.info(f" . 
{edge_type}, active: {edge.active}") + except Exception as err: + logger.error(f"Could not get belongs to information: {err}") + + if len(vertices) == 0: + logger.error(f" Does not belong to anyone") + + print("") + logger.info(f"Vertices Belonging To (Children)") + vertices = discovery_vertex.has_vertices() + for vertex in vertices: + try: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(discovery_vertex, edge_type=edge_type) + if edge is not None: + logger.info(f" . {edge_type}, active: {edge.active}") + except Exception as err: + logger.error(f"Could not get belonging to information: {err}") + if len(vertices) == 0: + logger.error(f" Does not have any children.") + + logger.info("") + else: + logger.error(f"Could not find infrastructure vertex.") + except Exception as err: + logger.error(f"Could not get information on infrastructure: {err}") diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_link.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_link.py new file mode 100644 index 00000000..8839d09a --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_link.py @@ -0,0 +1,67 @@ +import argparse + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... 
import api + +logger = api.get_logger() + +class PAMDebugLinkCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='pam action debug link') + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug link') + PAMDebugLinkCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + parser.add_argument('--configuration-uid', "-c", required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--resource-uid', '-r', required=True, dest='resource_uid', action='store', + help='Resource record UID.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. Default is 0', type=int, default=0) + + def execute(self, context: KeeperParams, **kwargs): + gateway = kwargs.get("gateway") + resource_uid = kwargs.get("resource_uid") + debug_level = int(kwargs.get("debug_level", 0)) + + logger.info("") + + configuration_uid = kwargs.get('configuration_uid') + + gateway_context = GatewayContext.from_gateway(context=context, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + record_link = RecordLink(record=gateway_context.configuration, + context=context, + logger=logger, + debug_level=debug_level) + + resource_record = context.vault.vault_data.load_record(resource_uid) + if resource_record is None: + logger.error(f"The parent record does not exists.") + return + + if resource_record.record_type not in [PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY]: + logger.error(f"The resource record type, {resource_record.record_type} " + f"is not allowed.") + return + + try: + 
record_link.belongs_to(resource_uid, gateway_context.configuration_uid, ) + record_link.save() + logger.info(f"Added link between '{resource_uid}' and " + f"{gateway_context.configuration_uid}") + except Exception as err: + logger.error(f"Could not add LINK: {err}") + raise err diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_rs.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_rs.py new file mode 100644 index 00000000..07b4df41 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_rs.py @@ -0,0 +1,218 @@ +import argparse +import re + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.dag_types import UserAcl, UserAclRotationSettings +from keepersdk.proto import router_pb2 +from ..discovery.__init__ import PAMGatewayActionDiscoverCommandBase +from .... 
import api +from ....helpers import router_utils +from keepersdk import utils + +logger = api.get_logger() + +class PAMDebugRotationSettingsCommand(PAMGatewayActionDiscoverCommandBase): + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug rotation') + PAMDebugRotationSettingsCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--user-record-uid', '-i', required=True, dest='user_record_uid', action='store', + help='PAM user record UID.') + parser.add_argument('--configuration-record-uid', '-c', required=False, + dest='configuration_record_uid', action='store', help='PAM configuration record UID.') + parser.add_argument('--resource-record-uid', '-r', required=False, + dest='resource_record_uid', action='store', help='PAM resource record UID.') + parser.add_argument('--noop', required=False, dest='noop', action='store_true', + help='User is part of a No Operation.') + parser.add_argument('--force', required=False, dest='force', action='store_true', + help='Force reset of the rotation settings.') + parser.add_argument('--dry-run', required=False, dest='dry_run', action='store_true', + help='Do not create or update anything.') + + def execute(self, context: KeeperParams, **kwargs): + user_record_uid = kwargs.get("user_record_uid") + resource_record_uid = kwargs.get("resource_record_uid") + configuration_record_uid = kwargs.get("configuration_record_uid") + noop = kwargs.get("noop", False) + force = kwargs.get("force", False) + dry_run = kwargs.get("dry_run", False) + vault = context.vault + + logger.info("") + + user_record = vault.vault_data.load_record(user_record_uid) + if user_record is None: + logger.error(f"The PAM user record does not exists.") + return + + if user_record.record_type != PAM_USER: + logger.error(f"The PAM user record is a {PAM_USER}. 
" + f"The record is {user_record.record_type}") + return + + record_rotation = params.record_rotation_cache.get(user_record_uid) + if record_rotation is None: + logger.warning(f"The protobuf rotation settings are missing. Attempting to create.") + + if configuration_record_uid is None: + logger.error(f"Cannot determine PAM configuration, please set the " + f"-c, --configuration-record-uid parameter for this command.") + return + + configuration_record = vault.vault_data.load_record(configuration_record_uid) + if configuration_record is None: + logger.error(f"Configuration record does not exists.") + return + + if re.search(r'^pam.+Configuration$', configuration_record.record_type) is None: + logger.error( + f"The configuration record is not a configuration record. " + f"It's {configuration_record.record_type} record.") + return + + if resource_record_uid is None: + while True: + yn = input("The resource record UID was not set. " + "This user does not belongs to a machine, database, or directory; " + "It's an IAM, Azure, or Domain Controller user? [Y/N]").lower() + if yn == "n": + logger.error(f"Since a resource is needed, please set --resource-record-uid, -r " + f"parameter for the this command.") + return + elif yn == "y": + break + + if resource_record_uid is not None: + + resource_record = vault.vault_data.load_record(resource_record_uid) + if resource_record is None: + logger.error(f"The resource record does not exists.") + return + + if resource_record.record_type not in [PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY]: + logger.error(f"The resource is NOT a " + f"{PAM_MACHINE}, {PAM_DATABASE}, or {PAM_DIRECTORY} record. " + f"It's a {resource_record.record_type}.") + return + + parent_uid = resource_record_uid or configuration_record_uid + + # Create rotation settings for the pamUser. 
+ rq = router_pb2.RouterRecordRotationRequest() + rq.recordUid = utils.base64_url_encode(user_record_uid) + rq.revision = 0 + rq.configurationUid = utils.base64_url_encode(configuration_record_uid) + rq.resourceUid = utils.base64_url_encode(parent_uid) + rq.schedule = '' + rq.pwdComplexity = b'' + rq.disabled = False + + if not dry_run: + router_utils.router_set_record_rotation_information(context, rq) + + context.sync_data = True + vault.sync_down() + + record_rotation = params.record_rotation_cache.get(user_record_uid) + if record_rotation is None: + logger.error(f"Protobuf rotation settings did not create.") + return + else: + logger.info(f"DRY RUN: Would have created the protobuf rotation settings.") + record_rotation = { + "configuration_uid": configuration_record_uid, + "resource_uid": resource_record_uid + } + + configuration_record_uid = record_rotation.get("configuration_uid") + if configuration_record_uid is None: + logger.error(f"Record does not have the PAM Configuration set.") + return + + logger.info(f"Configuration Record UID: {configuration_record_uid}") + + configuration_record = vault.vault_data.load_record(configuration_record_uid) + if configuration_record is None: + logger.error(f"Configuration record does not exists.") + return + + resource_record_uid = record_rotation.get("resource_uid") + if resource_record_uid is not None: + + logger.info(f"Resource Record UID: {resource_record_uid}") + + resource_record = vault.vault_data.load_record(resource_record_uid) + if resource_record is None: + logger.error(f"The resource record does not exists.") + return + + if resource_record.record_type not in [PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY]: + logger.error(f"The resource is a {PAM_MACHINE}, {PAM_DATABASE}, or {PAM_DIRECTORY} record. 
" + f"It's a {resource_record.record_type}.") + return + + record_link = RecordLink(record=configuration_record, context=context) + + parent_uid = resource_record_uid or configuration_record_uid + parent_vertex = record_link.get_record_link(parent_uid) + if parent_vertex is None: + parent_type = "configuration" + if resource_record_uid is not None: + parent_type = "resource" + logger.error(f"Could not find the parent linking vertex for the {parent_type}.") + return + + logger.info(f"User Record UID: {user_record_uid}") + + user_vertex = record_link.get_record_link(user_record_uid) + if user_vertex is None: + logger.warning(f"The user vertex is missing; creating.") + record_link.dag.add_vertex(uid=user_record_uid) + + user_acl = record_link.get_acl(user_record_uid, parent_uid) + if user_acl is None: + logger.warning(f"No ACL exists between the user and the parent; creating.") + user_acl = UserAcl.default() + user_acl.belongs_to = True + + logger.info("") + if user_acl.rotation_settings is not None: + if (force is False and ( + user_acl.rotation_settings.schedule != "" + or user_acl.rotation_settings.pwd_complexity != "" + or (user_acl.rotation_settings.saas_record_uid_list is not None + and len(user_acl.rotation_settings.saas_record_uid_list) != 0))): + logger.error(f"{user_acl.model_dump_json(indent=4)}") + logger.error(f"Rotation settings exist in graph, use --force to reset.") + return + + # Reset the rotation settings. + user_acl.rotation_settings = UserAclRotationSettings() + user_acl.rotation_settings.noop = noop + if resource_record_uid is None: + user_acl.is_iam_user = True + + # Connect the user to the parent (configuration or resource) + record_link.belongs_to(user_record_uid, parent_uid, acl=user_acl) + + # If parent is not a configuration, make sure there is a LINK from the resource to the configuration. 
+ if parent_uid != configuration_record_uid: + if record_link.get_parent_record_uid(parent_uid) is None: + logger.warning(f"Resource record has no LINK to configuration record; " + f"creating.") + record_link.belongs_to(configuration_record_uid, parent_uid) + + if not dry_run: + record_link.save() + + logger.info(f"{user_acl.model_dump_json(indent=4)}") + logger.info(f"Updated the ACL for the user.") + else: + logger.info(f"DRY RUN: Would have created this ACL.") + logger.info(f"{user_acl.model_dump_json(indent=4)}") diff --git a/keepercli-package/src/keepercli/commands/pam/debug/debug_vertex.py b/keepercli-package/src/keepercli/commands/pam/debug/debug_vertex.py new file mode 100644 index 00000000..971a64c7 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/debug/debug_vertex.py @@ -0,0 +1,199 @@ +import argparse +import re +import time + +from keepersdk.helpers.keeper_dag.user_service import Infrastructure + +from ....params import KeeperParams +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.dag_types import DiscoveryObject +from keepersdk.helpers.keeper_dag.dag import EdgeType +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from .... 
import api + +logger = api.get_logger() + +class PAMDebugVertexCommand(PAMGatewayActionDiscoverCommandBase): + type_name_map = { + PAM_USER: "PAM User", + PAM_MACHINE: "PAM Machine", + PAM_DATABASE: "PAM Database", + PAM_DIRECTORY: "PAM Directory", + } + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action debug vertex') + PAMDebugVertexCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--vertex', '-i', required=True, dest='vertex_uid', action='store', + help='Vertex in infrastructure graph') + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + debug_level = kwargs.get("debug_level", False) + + configuration_uid = kwargs.get('configuration_uid') + + gateway_context = GatewayContext.from_gateway(context=context, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + infra = Infrastructure(record=gateway_context.configuration, context=context, fail_on_corrupt=False, + debug_level=debug_level) + infra.load() + + vertex_uid = kwargs.get("vertex_uid") + vertex = infra.dag.get_vertex(vertex_uid) + if vertex is None: + logger.error(f"Could not find the vertex in the graph for {gateway}.") + return + + content = DiscoveryObject.get_discovery_object(vertex) + missing_since = "NA" + if content.missing_since_ts is not None: + missing_since = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(content.missing_since_ts)) + + logger.info(f"Discovery Object Information") + logger.info(f"Vertex UID: 
{content.uid}") + logger.info(f"Object ID: {content.id}") + logger.info(f"Record UID: {content.record_uid}") + logger.info(f"Parent Record UID: {content.parent_record_uid}") + logger.info(f"Shared Folder UID: {content.shared_folder_uid}") + logger.info(f"Record Type: {content.record_type}") + logger.info(f"Object Type: {content.object_type_value}") + logger.info(f"Ignore Object: {content.ignore_object}") + logger.info(f"Rule Engine Result: {content.action_rules_result}") + logger.info(f"Name: {content.name}") + logger.info(f"Generated Title: {content.title}") + logger.info(f"Generated Description: {content.description}") + logger.info(f"Missing Since: {missing_since}") + logger.info(f"Discovery Notes:") + for note in content.notes: + logger.info(f" * {note}") + if content.error is not None: + logger.error(f" Error: {content.error}") + if content.stacktrace is not None: + logger.error(f" Stack Trace:") + logger.error(f"{content.stacktrace}") + logger.info("") + logger.info(f"Record Type Specifics") + + if content.record_type == PAM_USER: + logger.info(f"User: {content.item.user}") + logger.info(f"DN: {content.item.dn}") + logger.info(f"Database: {content.item.database}") + logger.info(f"Active: {content.item.active}") + logger.info(f"Expired: {content.item.expired}") + logger.info(f"Source: {content.item.source}") + elif content.record_type == PAM_MACHINE: + logger.info(f"Host: {content.item.host}") + logger.info(f"IP: {content.item.ip}") + logger.info(f"Port: {content.item.port}") + logger.info(f"Operating System: {content.item.os}") + logger.info(f"Provider Region: {content.item.provider_region}") + logger.info(f"Provider Group: {content.item.provider_group}") + logger.info(f"Is the Gateway: {content.item.is_gateway}") + logger.info(f"Allows Admin: {content.item.allows_admin}") + logger.info(f"Admin Reason: {content.item.admin_reason}") + logger.info("") + # If facts are not set, inside discover may not have been performed for the machine. 
+ if content.item.facts.id is not None and content.item.facts.name is not None: + logger.info(f"Machine Name: {content.item.facts.name}") + logger.info(f"Machine ID: {content.item.facts.id.machine_id}") + logger.info(f"Product ID: {content.item.facts.id.product_id}") + logger.info(f"Board Serial: {content.item.facts.id.board_serial}") + logger.info(f"Directories:") + if content.item.facts.directories is not None and len(content.item.facts.directories) > 0: + for directory in content.item.facts.directories: + logger.info(f" * Directory Domain: {directory.domain}") + logger.info(f" Software: {directory.software}") + logger.info(f" Login Format: {directory.login_format}") + else: + logger.info(" Machines is not using any directories.") + + logger.info("") + logger.info(f"Services (Non Builtin Users):") + if len(content.item.facts.services) > 0: + for service in content.item.facts.services: + logger.info(f" * {service.name} = {service.user}") + else: + logger.info(" Machines has no services that are using non-builtin users.") + + logger.info(f"Scheduled Tasks (Non Builtin Users)") + if len(content.item.facts.tasks) > 0: + for task in content.item.facts.tasks: + logger.info(f" * {task.name} = {task.user}") + else: + logger.info(" Machines has no schedules tasks that are using non-builtin users.") + + logger.info(f"IIS Pools (Non Builtin Users)") + if len(content.item.facts.iis_pools) > 0: + for iis_pool in content.item.facts.iis_pools: + logger.info(f" * {iis_pool.name} = {iis_pool.user}") + else: + logger.info(" Machines has no IIS Pools that are using non-builtin users.") + + else: + logger.error(f" Machine facts are not set. 
Discover inside may not have been " + f"performed.") + elif content.record_type == PAM_DATABASE: + logger.info(f"Host: {content.item.host}") + logger.info(f"IP: {content.item.ip}") + logger.info(f"Port: {content.item.port}") + logger.info(f"Database Type: {content.item.type}") + logger.info(f"Database: {content.item.database}") + logger.info(f"Use SSL: {content.item.use_ssl}") + logger.info(f"Provider Region: {content.item.provider_region}") + logger.info(f"Provider Group: {content.item.provider_group}") + logger.info(f"Allows Admin: {content.item.allows_admin}") + logger.info(f"Admin Reason: {content.item.admin_reason}") + elif content.record_type == PAM_DIRECTORY: + logger.info(f"Host: {content.item.host}") + logger.info(f"IP: {content.item.ip}") + logger.info(f"Port: {content.item.port}") + logger.info(f"Directory Type: {content.item.type}") + logger.info(f"Use SSL: {content.item.use_ssl}") + logger.info(f"Provider Region: {content.item.provider_region}") + logger.info(f"Provider Group: {content.item.provider_group}") + logger.info(f"Allows Admin: {content.item.allows_admin}") + logger.info(f"Admin Reason: {content.item.admin_reason}") + + logger.info("") + logger.info(f"Belongs To Vertices (Parents)") + vertices = vertex.belongs_to_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(vertex, edge_type=edge_type) + if edge is not None: + logger.info(f" . 
{edge_type}, active: {edge.active}") + + if len(vertices) == 0: + logger.error(f" Does not belong to anyone") + + logger.info("") + logger.info(f"Vertices Belonging To (Children)") + vertices = vertex.has_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(vertex, edge_type=edge_type) + if edge is not None: + logger.info(f" . {edge_type}, active: {edge.active}") + if len(vertices) == 0: + logger.info(f" Does not have any children.") + + logger.info("") diff --git a/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py b/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py new file mode 100644 index 00000000..5f80e434 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py @@ -0,0 +1,378 @@ + +import base64 +import json +import os +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + + +from .... 
import api +from ....commands import base +from ....helpers import router_utils +from ....helpers.gateway_utils import get_all_gateways + + +from keepersdk.vault import vault_record, vault_online +from keepersdk.helpers.pam_config_facade import PamConfigurationRecordFacade +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.proto import pam_pb2, APIRequest_pb2 +from keepersdk import utils +from keepersdk.crypto import decrypt_aes_v2, encrypt_aes_v2 +from keepersdk.helpers.keeper_dag import dag_utils + +logger = api.get_logger() + + +class MultiConfigurationException(Exception): + """ + If the gateway has multiple configuration + """ + def __init__(self, items: List[Dict]): + super().__init__() + self.items = items + + def print_items(self): + for item in self.items: + record = item["configuration_record"] + logger.info(f" * {record.record_uid} - {record.title}") + + +class GatewayContext: + + """ + Context for a gateway and a configuration. + + In the configuration record, the gateway is selected. + This means multiple configuration can use the same gateway. + Commander is gateway centric, we need to treat gateway and configuration as a `primary key` + + Since we get the configuration record from the vault, go through each of them and see if that gateway + is only used by one configuration. + If it is, then that gateway and configuration pair are used. + If there are multiple configuration, we need to throw an MultiConfigurationException. 
+ + """ + + def __init__(self, configuration: vault_record.KeeperRecord, facade: PamConfigurationRecordFacade, + gateway: pam_pb2.PAMController, application: vault_record.ApplicationRecord): + self.configuration = configuration + self.facade = facade + self.gateway = gateway + self.application = application + self._shared_folders = None + + @staticmethod + def all_gateways(vault: vault_online.VaultOnline): + return get_all_gateways(vault) + + @staticmethod + def find_gateway(vault: vault_online.VaultOnline, find_func: Callable, gateways: Optional[List] = None) \ + -> Tuple[Optional["GatewayContext"], Any]: + + """ + Populate the context from matching using the function passed in. + The function needs to return a non-None value to be considered a positive match. + + """ + + if gateways is None: + gateways = GatewayContext.all_gateways(vault) + + configuration_records = list() + for configuration_record in configuration_records: + payload = find_func( + configuration_record=configuration_record + ) + if payload is not None: + return GatewayContext.from_configuration_uid( + vault=vault, + configuration_uid=configuration_record.record_uid, + gateways=gateways + ), payload + + return None, None + + @staticmethod + def from_configuration_uid(vault: vault_online.VaultOnline, configuration_uid: str, gateways: Optional[List] = None) \ + -> Optional["GatewayContext"]: + + """ + Populate context using the configuration UID. + + From the configuration record, get the gateway from the settings. 
+ + """ + + if gateways is None: + gateways = GatewayContext.all_gateways(vault) + + configuration_record = vault.vault_data.load_record(configuration_uid) + if not isinstance(configuration_record, vault_record.TypedRecord): + logger.error(f'PAM Configuration [{configuration_uid}] is not available.') + return None + + configuration_facade = PamConfigurationRecordFacade() + configuration_facade.record = configuration_record + + gateway_uid = configuration_facade.controller_uid + gateway = next((x for x in gateways + if utils.base64_url_encode(x.controllerUid) == gateway_uid), + None) + + if gateway is None: + return None + + application_id = utils.base64_url_encode(gateway.applicationUid) + application = vault.vault_data.load_record(application_id) + + return GatewayContext( + configuration=configuration_record, + facade=configuration_facade, + gateway=gateway, + application=application + ) + + @staticmethod + def from_gateway(vault: vault_online.VaultOnline, gateway: str, configuration_uid: Optional[str] = None) \ + -> Optional["GatewayContext"]: + + """ + Populate context use the gateway, and optional configuration UID. + + This will scan all configuration to find which ones use this gateway. + If there are multiple ones, a MultiConfigurationException is thrown. + If there is only one gateway, then that gateway is used. + + """ + # Get all the PAM configuration records in the Vault; not Application + configuration_records = list(vault.vault_data.find_records("pam.*Configuration")) + + if configuration_uid: + logger.debug(f"find the gateway with configuration record {configuration_uid}") + + # You get this if the user has not setup any PAM related records. 
+ if len(configuration_records) == 0: + logger.error(f"Cannot find any PAM configuration records in the Vault") + return None + + all_gateways = get_all_gateways(vault) + found_items = [] + for configuration_record in configuration_records: + + logger.debug(f"checking configuration record {configuration_record.title}") + + # Load the configuration record and get the gateway_uid from the facade. + configuration_record = vault.vault_data.load_record(configuration_record.record_uid) + configuration_facade = PamConfigurationRecordFacade() + configuration_facade.record = configuration_record + + configuration_gateway_uid = configuration_facade.controller_uid + if configuration_gateway_uid is None: + logger.debug(f" * configuration {configuration_record.title} does not have a gateway set, skipping.") + continue + + # Get the gateway for this configuration + found_gateway = next((x for x in all_gateways if utils.base64_url_encode(x.controllerUid) == + configuration_gateway_uid), None) + if found_gateway is None: + logger.debug(f" * configuration does not use desired gateway") + continue + + # If the configuration_uid was passed in, and we find it, just set the found items to this + # configuration and stop checking for more. 
+ if configuration_uid is not None and configuration_uid == configuration_record.record_uid: + logger.debug(f" * configuration record uses this gateway and matches desire configuration, " + "skipping the rest") + found_items = [{ + "configuration_facade": configuration_facade, + "configuration_record": configuration_record, + "gateway": found_gateway + }] + break + + if (utils.base64_url_encode(found_gateway.controllerUid) == gateway or + found_gateway.controllerName.lower() == gateway.lower()): + logger.debug(f" * configuration record uses this gateway") + found_items.append({ + "configuration_facade": configuration_facade, + "configuration_record": configuration_record, + "gateway": found_gateway + }) + + if len(found_items) > 1: + logger.debug(f"found {len(found_items)} configurations using this gateway") + raise MultiConfigurationException( + items=found_items + ) + + if len(found_items) == 1: + found_gateway = found_items[0]["gateway"] + configuration_record = found_items[0]["configuration_record"] + configuration_facade = found_items[0]["configuration_facade"] + + application_id = utils.base64_url_encode(found_gateway.applicationUid) + application = vault.vault_data.load_record(application_id) + if application is None: + logger.debug(f"cannot find application for gateway {gateway}, skipping.") + + if (utils.base64_url_encode(found_gateway.controllerUid) == gateway or + found_gateway.controllerName.lower() == gateway.lower()): + return GatewayContext( + configuration=configuration_record, + facade=configuration_facade, + gateway=found_gateway, + application=application + ) + + return None + + @property + def gateway_uid(self) -> str: + return utils.base64_url_encode(self.gateway.controllerUid) + + @property + def configuration_uid(self) -> str: + return self.configuration.record_uid + + @property + def gateway_name(self) -> str: + return self.gateway.controllerName + + @property + def default_shared_folder_uid(self) -> str: + return self.facade.folder_uid + + 
def is_gateway(self, request_gateway: str) -> bool: + if request_gateway is None or self.gateway_name is None: + return False + return (request_gateway == utils.base64_url_encode(self.gateway.controllerUid) or + request_gateway.lower() == self.gateway_name.lower()) + + def get_shared_folders(self, vault: vault_online.VaultOnline) -> List[dict]: + if self._shared_folders is None: + self._shared_folders = [] + application_uid = utils.base64_url_encode(self.gateway.applicationUid) + app_info = vault.vault_data.load_record(application_uid) + for info in app_info: + if info.shares is None: + continue + for shared in info.shares: + uid_str = utils.base64_url_encode(shared.secretUid) + shared_type = APIRequest_pb2.ApplicationShareType.Name(shared.shareType) + if shared_type == 'SHARE_TYPE_FOLDER': + if uid_str not in vault.vault_data._shared_folders: + continue + cached_shared_folder = vault.vault_data._shared_folders[uid_str] + self._shared_folders.append({ + "uid": uid_str, + "name": cached_shared_folder.get('name_unencrypted'), + "folder": cached_shared_folder + }) + return self._shared_folders + + def decrypt(self, cipher_base64: bytes) -> dict: + ciphertext = base64.b64decode(cipher_base64.decode()) + return json.loads(decrypt_aes_v2(ciphertext, self.configuration.record_key)) + + def encrypt(self, data: dict) -> str: + json_data = json.dumps(data) + ciphertext = encrypt_aes_v2(json_data.encode(), self.configuration.record_key) + return base64.b64encode(ciphertext).decode() + + def encrypt_str(self, data: Union[bytes, str]) -> str: + if isinstance(data, str): + data = data.encode() + ciphertext = encrypt_aes_v2(data, self.configuration.record_key) + return base64.b64encode(ciphertext).decode() + + @staticmethod + def get_configuration_records(vault: vault_online.VaultOnline) -> List[vault_record.KeeperRecord]: + + """ + Get PAM configuration records. + + The default it to find all the record version 6 records. 
+ If the environment variable `PAM_RECORD_TYPE_MATCH` is set to a true value, the search will use both record + versions 3 and 6, and then check the record type. + """ + + configuration_list = [] + if dag_utils.value_to_boolean(os.environ.get("PAM_RECORD_TYPE_MATCH")): + for record in list(vault.vault_data.find_records(record_version=iter([3, 6]))): + if re.search(r"pam.+Configuration", record.record_type): + configuration_list.append(record) + else: + configuration_list = list(vault.vault_data.find_records(record_version=6)) + return configuration_list + + +class PAMGatewayActionDiscoverCommandBase(base.ArgparseCommand): + + """ + The discover command base. + + Contains static methods to get the configuration record, get and update the discovery store. These are methods + used by multiple discover actions. + """ + + # If the discovery data field does not exist, or the field contains no values, use the template to init the + # field. + + STORE_LABEL = "discoveryKey" + FIELD_MAPPING = { + "pamHostname": { + "type": "dict", + "field_input": [ + {"key": "hostName", "prompt": "Hostname"}, + {"key": "port", "prompt": "Port"} + ], + "field_format": [ + {"key": "hostName", "label": "Hostname"}, + {"key": "port", "label": "Port"}, + ] + }, + "alternativeIPs": { + "type": "csv", + }, + "privatePEMKey": { + "type": "multiline", + }, + "operatingSystem": { + "type": "choice", + "values": ["linux", "macos", "windows", "cisco_ios_xe"] + } + } + + type_name_map = { + PAM_USER: "PAM Users", + PAM_MACHINE: "PAM Machines", + PAM_DATABASE: "PAM Databases", + PAM_DIRECTORY: "PAM Directories", + } + + @staticmethod + def get_response_data(router_response: dict) -> Optional[dict]: + + if router_response is None: + return None + + response = router_response.get("response") + logger.debug(f"Router Response: {response}") + payload = router_utils.get_response_payload(router_response) + return payload.get("data") + + @staticmethod + def _p(msg): + return msg + + @staticmethod + def 
_n(record_type): + return PAMGatewayActionDiscoverCommandBase.type_name_map.get(record_type, "PAM Configuration") + + + +def multi_conf_msg(gateway: str, err: MultiConfigurationException): + logger.info(f"Found multiple configuration records for gateway {gateway}.") + logger.info("Please use the --configuration-uid parameter to select the configuration.") + logger.info("Available configurations are: ") + err.print_items() \ No newline at end of file diff --git a/keepercli-package/src/keepercli/commands/pam/discovery/discover.py b/keepercli-package/src/keepercli/commands/pam/discovery/discover.py new file mode 100644 index 00000000..331a6137 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/discovery/discover.py @@ -0,0 +1,2194 @@ +import argparse +import json +import os +import sys +from typing import Any, Dict, List, Optional, Tuple +from pydantic import BaseModel + +from keepersdk import crypto, utils + +from ... import base +from ....params import KeeperParams +from ....helpers import router_utils +from .... 
import api +from .__init__ import GatewayContext, MultiConfigurationException, multi_conf_msg, PAMGatewayActionDiscoverCommandBase +from ..pam_dto import GatewayAction, GatewayActionDiscoverJobStartInputs, GatewayActionDiscoverJobStart, GatewayActionDiscoverJobRemoveInputs, GatewayActionDiscoverJobRemove +from .rule_commands import PAMGatewayActionDiscoverRuleAddCommand, PAMGatewayActionDiscoverRuleListCommand, PAMGatewayActionDiscoverRuleRemoveCommand, PAMGatewayActionDiscoverRuleUpdateCommand + +from keepersdk.helpers.pam_user_record_facade import PamUserRecordFacade +from keepersdk.helpers.keeper_dag.jobs import Jobs +from keepersdk.helpers.keeper_dag.dag_types import (CredentialBase, DiscoveryDelta, DiscoveryObject, JobItem, UserAcl, DirectoryInfo, + BulkRecordConvert, BulkRecordAdd, BulkRecordSuccess, BulkRecordFail, BulkProcessResults, NormalizedRecord, BulkRecordFail, PromptResult, + PromptActionEnum) +from keepersdk.helpers.keeper_dag.dag_vertex import DAGVertex +from keepersdk.helpers.keeper_dag.dag import DAG +from keepersdk.helpers.keeper_dag.dag_sort import sort_infra_vertices +from keepersdk.helpers.keeper_dag.constants import VERTICES_SORT_MAP, DIS_INFRA_GRAPH_ID, PAM_USER +from keepersdk.helpers.keeper_dag.infrastructure import Infrastructure +from keepersdk.helpers.keeper_dag.process import Process, NoDiscoveryDataException, QuitException +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.vault import record_types, vault_extensions, vault_online, vault_record +from keepersdk.proto import pam_pb2, record_pb2, router_pb2 + +logger = api.get_logger() + + +class PAMGatewayActionDiscoverJobStatusCommand(PAMGatewayActionDiscoverCommandBase): + """ + Get the status of discovery jobs. + + If no parameters are given, it will check all gateways for discovery job status. 
+ + """ + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover status') + PAMGatewayActionDiscoverJobStatusCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=False, dest='gateway', action='store', + help='Show only discovery jobs from a specific gateway.') + parser.add_argument('--job-id', '-j', required=False, dest='job_id', action='store', + help='Detailed information for a specific discovery job.') + parser.add_argument('--history', required=False, dest='show_history', action='store_true', + help='Show history') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID is using --history') + + @staticmethod + def print_job_table(jobs: List[Dict], + max_gateway_name: int, + show_history: bool = False): + + """ + Print jobs in a table. + + This method takes a list of dictionary item which contains the cooked job information. 
+ + """ + + logger.info("") + logger.info(f"{'Job ID'.ljust(14, ' ')} " + f"{'Gateway Name'.ljust(max_gateway_name, ' ')} " + f"{'Gateway UID'.ljust(22, ' ')} " + f"{'Configuration UID'.ljust(22, ' ')} " + f"{'Status'.ljust(12, ' ')} " + f"{'Resource UID'.ljust(22, ' ')} " + f"{'Started'.ljust(19, ' ')} " + f"{'Completed'.ljust(19, ' ')} " + f"{'Duration'.ljust(19, ' ')} " + f"") + + logger.info(f"{''.ljust(14, '=')} " + f"{''.ljust(max_gateway_name, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(12, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(19, '=')} " + f"{''.ljust(19, '=')} " + f"{''.ljust(19, '=')}") + + completed_jobs = [] + running_jobs = [] + failed_jobs = [] + + for job in jobs: + job_id = job['job_id'] + if job['status'] == "COMPLETE": + completed_jobs.append(job_id) + elif job['status'] == "RUNNING": + running_jobs.append(job_id) + elif job['status'] == "FAILED": + failed_jobs.append(job_id) + logger.info(f"{job_id} " + f"{job['gateway'].ljust(max_gateway_name, ' ')} " + f"{job['gateway_uid']} " + f"{job['configuration_uid']} " + f"{job['status'].ljust(12, ' ')} " + f"{(job.get('resource_uid') or 'NA').ljust(22, ' ')} " + f"{(job.get('start_ts_str') or 'NA').ljust(19, ' ')} " + f"{(job.get('end_ts_str') or 'NA').ljust(19, ' ')} " + f"{(job.get('duration') or 'NA').ljust(19, ' ')} " + f"") + + if len(completed_jobs) > 0 and show_history is False: + logger.info("") + if len(completed_jobs) == 1: + logger.info(f"There is one COMPLETED job. To process, use the following command.") + else: + logger.info(f"There are {len(completed_jobs)} COMPLETED jobs. " + "To process, use one of the the following commands.") + for job_id in completed_jobs: + logger.info(f" pam action discover process -j {job_id}") + + if len(running_jobs) > 0 and show_history is False: + logger.info("") + if len(running_jobs) == 1: + logger.info(f"There is one RUNNING job. 
" + "If there is a problem, use the following command to cancel/remove the job.") + else: + logger.info(f"There are {len(running_jobs)} RUNNING jobs. " + "If there is a problem, use one of the following commands to cancel/remove the job.") + for job_id in running_jobs: + logger.info(f" pam action discover remove -j {job_id}") + + if len(failed_jobs) > 0 and show_history is False: + logger.info("") + if len(failed_jobs) == 1: + logger.info(f"There is one FAILED job. " + "If there is a problem, use the following command to get more information.") + else: + logger.info(f"There are {len(failed_jobs)} FAILED jobs. " + "If there is a problem, use one of the following commands to get more information.") + for job_id in failed_jobs: + logger.info(f" pam action discover status -j {job_id}") + logger.info("") + if len(failed_jobs) == 1: + logger.info(f"To remove the job, use the following command.") + else: + logger.info(f"To remove the FAILED job, use one of the following commands.") + for job_id in failed_jobs: + logger.info(f" pam action discover remove -j {job_id}") + + logger.info("") + + @staticmethod + def print_job_detail(vault: vault_online.VaultOnline, + all_gateways: List, + job_id: str): + + def _find_job(configuration_record) -> Optional[Dict]: + jobs_obj = Jobs(record=configuration_record) + job_item = jobs_obj.get_job(job_id) + if job_item is not None: + return { + "jobs": jobs_obj, + } + return None + + gateway_context, payload = GatewayContext.find_gateway(vault=vault, + find_func=_find_job, + gateways=all_gateways) + + if gateway_context is not None: + jobs = payload["jobs"] + job = jobs.get_job(job_id) + infra = Infrastructure(record=gateway_context.configuration) + + status = "RUNNING" + if job.end_ts is not None and not job.error: + if job.success is None: + status = "CANCELLED" + else: + status = "COMPLETE" + elif job.error: + status = "FAILED" + + logger.info("") + logger.info(f"Job ID: {job.job_id}") + logger.info(f"Sync Point: {job.sync_point}") + 
logger.info(f"Gateway Name: {gateway_context.gateway_name}") + logger.info(f"Gateway UID: {gateway_context.gateway_uid}") + logger.info(f"Configuration UID: {gateway_context.configuration_uid}") + logger.info(f"Status: {status}") + logger.info(f"Resource UID: {job.resource_uid or 'NA'}") + logger.info(f"Started: {job.start_ts_str}") + logger.info(f"Completed: {job.end_ts_str}") + logger.info(f"Duration: {job.duration_sec_str}") + + # If it failed, show the error and stacktrace. + if status == "FAILED": + logger.info("") + logger.info(f"Gateway Error:") + logger.info(f"{job.error}") + logger.info("") + logger.info(f"Gateway Stacktrace:") + logger.info(f"{job.stacktrace}") + # If it finished, show information about what was discovered. + elif job.end_ts is not None: + + try: + infra.load(sync_point=0) + logger.info("") + delta_json = job.delta + if delta_json is not None: + delta = DiscoveryDelta.model_validate(delta_json) + logger.info(f"Added - {len(delta.added)} count") + for item in delta.added: + vertex = infra.dag.get_vertex(item.uid) + if vertex is None or vertex.active is False or vertex.has_data is False: + logger.debug("added: vertex is none, inactive or has no data") + continue + discovery_object = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {discovery_object.description}") + + logger.info("") + logger.info(f"Changed - {len(delta.changed)} count") + for item in delta.changed: + vertex = infra.dag.get_vertex(item.uid) + if vertex is None or vertex.active is False or vertex.has_data is False: + logger.debug("changed: vertex is none, inactive or has no data") + continue + discovery_object = DiscoveryObject.get_discovery_object(vertex) + logger.info(f" * {discovery_object.description}") + if item.changes is None: + logger.info(f" no changed, may be a object not added in prior discoveries.") + else: + for key, value in item.changes.items(): + logger.info(f" - {key} = {value}") + + logger.info("") + logger.info(f"Deleted - 
{len(delta.deleted)} count") + for item in delta.deleted: + logger.info(f" * discovery vertex {item.uid}") + else: + logger.info(f"There are no available delta changes for this job.") + + except Exception as err: + logger.info(f"Could not load delta from infrastructure: {str(err)}") + logger.info("Fall back to raw graph.") + logger.info("") + dag = DAG(conn=infra.conn, record=infra.record, graph_id=DIS_INFRA_GRAPH_ID) + logger.info(dag.to_dot_raw(sync_point=job.sync_point, rank_dir="RL")) + + else: + logger.info(f"Could not find the gateway with job {job_id}.") + + def execute(self, context: KeeperParams, **kwargs): + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + vault = context.vault + + # If this is set, only show status for this gateway and history for this gateway. + gateway_filter = kwargs.get("gateway") + + # If this is set, only show detailed information about this job. + job_id = kwargs.get("job_id") + + # Show the history for the gateway. + # gateway_filter needs to be set for + show_history = kwargs.get("show_history") + + # Get all the gateways here so we don't have to keep calling this method. + # It gets passed into find_gateway, and find_gateway will pass it around. + all_gateways = GatewayContext.all_gateways(vault) + + # If we are showing all gateways, disable show history. + # History is shown for a specific gateway. + if gateway_filter is None: + show_history = False + + # This is used to format the table. Start with a length of 12 characters for the gateway. + max_gateway_name = 12 + + # If we have a job id, only display information about the one job + if job_id: + self.print_job_detail(vault=vault, + all_gateways=all_gateways, + job_id=job_id) + + # Else show jobs in a table + else: + + # Based on parameters set by user, select specific jobs to be displayed. + selected_jobs = [] # type: List[Dict] + + # For each configuration/ gateway, we are going to get all jobs. 
+ # We are going to query the gateway for any updated status. + + configuration_records = list(vault.vault_data.find_records("pam.*Configuration")) + for configuration_record in configuration_records: + + gateway_context = GatewayContext.from_configuration_uid( + vault=vault, + configuration_uid=configuration_record.record_uid, + gateways=all_gateways) + + if gateway_context is None: + continue + + # If we are using a gateway filter, and this gateway is not the one, then go onto the next conf/gateway. + if gateway_filter is not None and gateway_context.is_gateway(gateway_filter) is False: + continue + + # If the gateway name is longer that the prior, set the max length to this gateway's name. + if len(gateway_context.gateway_name) > max_gateway_name: + max_gateway_name = len(gateway_context.gateway_name) + + jobs = Jobs(record=configuration_record) + if show_history is True: + job_list = reversed(jobs.history) + else: + job_list = [] + if jobs.current_job is not None: + job_list = [jobs.current_job] + + for job_item in job_list: + job = job_item.model_dump() + job["status"] = "RUNNING" + if job_item.start_ts is not None: + job["start_ts_str"] = job_item.start_ts_str + if job_item.end_ts is not None: + job["end_ts_str"] = job_item.end_ts_str + job["status"] = "COMPLETE" + + job["duration"] = job_item.duration_sec_str + + job["gateway"] = gateway_context.gateway_name + job["gateway_uid"] = gateway_context.gateway_uid + job["configuration_uid"] = gateway_context.configuration_uid + + # This is needs for details + job["gateway_context"] = gateway_context + job["job_item"] = job_item + + if job_item.success is None and job_item.end_ts: + job["status"] = "CANCELLED" + elif job_item.success is False: + job["status"] = "FAILED" + + selected_jobs.append(job) + + if len(selected_jobs) == 0: + logger.info(f"There are no discovery jobs. 
Use 'pam action discover start' to start a " + f"discovery job.") + return + + self.print_job_table(jobs=selected_jobs, + max_gateway_name=max_gateway_name, + show_history=show_history) + + +class PAMGatewayActionDiscoverJobStartCommand(base.ArgparseCommand): + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover start') + PAMGatewayActionDiscoverJobStartCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name of UID.') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--resource', '-r', required=False, dest='resource_uid', action='store', + help='UID of the resource record. Set to discover specific resource.') + + parser.add_argument('--lang', required=False, dest='language', action='store', default="en_US", + help='Language') + parser.add_argument('--include-machine-dir-users', required=False, dest='include_machine_dir_users', + action='store_false', default=True, help='Include directory users found on the machine.') + parser.add_argument('--inc-azure-aadds', required=False, dest='include_azure_aadds', + action='store_true', help='Include Azure Active Directory Domain Service.') + parser.add_argument('--skip-rules', required=False, dest='skip_rules', + action='store_true', help='Skip running the rule engine.') + parser.add_argument('--skip-machines', required=False, dest='skip_machines', + action='store_true', help='Skip discovering machines.') + parser.add_argument('--skip-databases', required=False, dest='skip_databases', + action='store_true', help='Skip discovering databases.') + parser.add_argument('--skip-directories', required=False, dest='skip_directories', + action='store_true', help='Skip 
discovering directories.') + parser.add_argument('--skip-cloud-users', required=False, dest='skip_cloud_users', + action='store_true', help='Skip discovering cloud users.') + + def execute(self, context: KeeperParams, **kwargs): + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + vault = context.vault + + # Load the configuration record and get the gateway_uid from the facade. + gateway = kwargs.get('gateway') + gateway_context = None + try: + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway=gateway, + configuration_uid=kwargs.get('configuration_uid')) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + except MultiConfigurationException as err: + multi_conf_msg(gateway, err) + return + + jobs = Jobs(record=gateway_context.configuration) + current_job_item = jobs.current_job + removed_prior_job = None + if current_job_item is not None: + if current_job_item.is_running is True: + logger.warning("A discovery job is currently running. 
Cannot start another until it is finished.") + logger.warning("To check the status, use the command 'pam action discover status'.") + logger.warning(f"To stop and remove the current job, use the command 'pam action discover remove -j {current_job_item.job_id}'.") + return + + logger.error(f"An active discovery job exists for this gateway.") + logger.info("") + status = PAMGatewayActionDiscoverJobStatusCommand() + status.execute(context=context) + logger.info("") + + yn = input("Do you wish to remove the active discovery job and run a new one [Y/N]> ").lower() + while True: + if yn[0] == "y": + jobs.cancel(current_job_item.job_id) + removed_prior_job = current_job_item.job_id + break + elif yn[0] == "n": + logger.error(f"Not starting a discovery job.") + return + + # Get the credentials passed in via the command line + credentials = [] + creds = kwargs.get('credentials') + if creds is not None: + for cred in creds: + parts = cred.split("|") + c = CredentialBase() + for item in parts: + kv = item.split("=") + if len(kv) != 2: + logger.error(f"A '--cred' is invalid. It does not have a value.") + return + if not hasattr(c, kv[0]): + logger.error(f"A '--cred' is invalid. The key '{kv[0]}' is invalid.") + return + if hasattr(c, kv[1]) == "": + logger.error(f"A '--cred' is invalid. The value is blank.") + return + setattr(c, kv[0], kv[1]) + credentials.append(c.model_dump()) + + # Get the credentials passed in via a credential file. 
+ credential_files = kwargs.get('credential_file') + if credential_files is not None: + with open(credential_files, "r") as fh: + try: + creds = json.load(fh) + except FileNotFoundError: + logger.error(f"Could not find the file {credential_files}") + return + except json.JSONDecoder: + logger.error(f"The file {credential_files} is not valid JSON.") + return + except Exception as err: + logger.error(f"The JSON file {credential_files} could not be imported: {err}") + return + + if not isinstance(creds, list): + logger.error(f"Credential file is invalid. Structure is not an array.") + return + num = 1 + for obj in creds: + c = CredentialBase() + for key in obj: + if not hasattr(c, key): + logger.error(f"Object {num} has the invalid key {key}.") + return + setattr(c, key, obj[key]) + credentials.append(c.model_dump()) + + action_inputs = GatewayActionDiscoverJobStartInputs( + configuration_uid=gateway_context.configuration_uid, + resource_uid=kwargs.get('resource_uid'), + user_map=gateway_context.encrypt( + self.make_protobuf_user_map( + context=context, + gateway_context=gateway_context + ) + ), + + shared_folder_uid=gateway_context.default_shared_folder_uid, + languages=[kwargs.get('language')], + + # Settings + include_machine_dir_users=kwargs.get('include_machine_dir_users', True), + include_azure_aadds=kwargs.get('include_azure_aadds', False), + skip_rules=kwargs.get('skip_rules', False), + skip_machines=kwargs.get('skip_machines', False), + skip_databases=kwargs.get('skip_databases', False), + skip_directories=kwargs.get('skip_directories', False), + skip_cloud_users=kwargs.get('skip_cloud_users', False), + credentials=credentials + ) + + conversation_id = GatewayAction.generate_conversation_id() + router_response = router_utils.router_send_action_to_gateway( + context=context, + gateway_action=GatewayActionDiscoverJobStart( + inputs=action_inputs, + conversation_id=conversation_id), + message_type=pam_pb2.CMT_DISCOVERY, + is_streaming=False, + 
destination_gateway_uid_str=gateway_context.gateway_uid + ) + + data = self.get_response_data(router_response) + if data is None: + logger.error(f"The router returned a failure.") + return + + if "has been queued" in data.get("Response", ""): + + if removed_prior_job is None: + logger.info("The discovery job is currently running.") + else: + logger.info(f"Active discovery job {removed_prior_job} has been removed and new discovery job is running.") + logger.info(f"To check the status, use the command 'pam action discover status'.") + logger.info(f"To stop and remove the current job, use the command 'pam action discover remove -j '.") + else: + router_utils.print_router_response(router_response, "job_info", conversation_id, gateway_uid=gateway_context.gateway_uid) + + @staticmethod + def make_protobuf_user_map(vault: vault_online.VaultOnline, gateway_context: GatewayContext) -> List[dict]: + """ + Make a user map for PAM Users. + + The map is used to find existing records. + Since KSM cannot read the rotation settings using protobuf, + it cannot match a vault record to a discovered users. + This map will map a login/DN and parent UID to a record UID. + """ + + user_map = [] + for record in vault.vault_data.find_records("pamUser"): + user_record = vault.vault_data.load_record(record.record_uid) + user_facade = PamUserRecordFacade() + user_facade.record = user_record + + info = params.record_rotation_cache.get(user_record.record_uid) + if info is None: + continue + + # Make sure this user is part of this gateway. + if info.get("configuration_uid") != gateway_context.configuration_uid: + continue + + # If the user Admin Cred Record (i.e., parent) is blank, skip the mapping item + # This will be a UID string, not 16 bytes. 
+ if info.get("resource_uid") is None or info.get("resource_uid") == "": + continue + + user_map.append({ + "user": user_facade.login if user_facade.login != "" else None, + "dn": user_facade.distinguishedName if user_facade.distinguishedName != "" else None, + "record_uid": user_record.record_uid, + "parent_record_uid": info.get("resource_uid") + }) + + logger.debug(f"found {len(user_map)} user map items") + + return user_map + +class PAMGatewayActionDiscoverJobRemoveCommand(PAMGatewayActionDiscoverCommandBase): + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover remove') + PAMGatewayActionDiscoverJobRemoveCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store', + help='Discovery job id.') + + def execute(self, context: KeeperParams, **kwargs): + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + vault = context.vault + + job_id = kwargs.get("job_id") + + # Get all the gateways here so we don't have to keep calling this method. + # It gets passed into find_gateway, and find_gateway will pass it around. + all_gateways = GatewayContext.all_gateways(vault) + + def _find_job(configuration_record) -> Optional[Dict]: + jobs_obj = Jobs(record=configuration_record) + job_item = jobs_obj.get_job(job_id) + if job_item is not None: + return { + "jobs": jobs_obj, + } + return None + + gateway_context, payload = GatewayContext.find_gateway(vault=vault, + find_func=_find_job, + gateways=all_gateways) + + if gateway_context is not None: + jobs = payload["jobs"] + + try: + # First, cancel the running discovery job if it is running. 
+ logger.debug("cancel job on the gateway, if running") + action_inputs = GatewayActionDiscoverJobRemoveInputs( + configuration_uid=gateway_context.configuration_uid, + job_id=job_id + ) + + conversation_id = GatewayAction.generate_conversation_id() + router_response = router_utils.router_send_action_to_gateway( + context=context, + gateway_action=GatewayActionDiscoverJobRemove( + inputs=action_inputs, + conversation_id=conversation_id), + message_type=pam_pb2.CMT_DISCOVERY, + is_streaming=False, + destination_gateway_uid_str=gateway_context.gateway_uid + ) + + data = self.get_response_data(router_response) + if data is None: + raise Exception("The router returned a failure.") + elif data.get("success") is False: + error = data.get("error") + raise Exception(f"Discovery job was not removed: {error}") + except Exception as err: + logger.debug(f"gateway return error removing discovery job: {err}") + + jobs.cancel(job_id) + jobs.close() + + logger.info(f"Discovery job has been removed or cancelled.") + return + + logger.error(f'Discovery job not found. 
Cannot get remove the job.') + return + + +# This is used for the admin user search +class AdminSearchResult(BaseModel): + record: Any + is_directory_user: bool + is_pam_user: bool + being_used: bool = False + + +class PAMGatewayActionDiscoverResultProcessCommand(PAMGatewayActionDiscoverCommandBase): + + EDITABLE = [ + "login", + "password", + "distinguishedName", + "alternativeIPs", + "database", + "privatePEMKey", + "connectDatabase", + "operatingSystem" + ] + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover process') + PAMGatewayActionDiscoverResultProcessCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store', + help='Discovery job to process.') + parser.add_argument('--add-all', required=False, dest='add_all', action='store_true', + help='Respond with ADD for all prompts.') + parser.add_argument('--preview', required=False, dest='do_preview', action='store_true', + help='Preview the results') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. Default is 0', type=int, default=0) + + + @staticmethod + def _is_directory_user(record_type: str) -> bool: + # pamAzureConfiguration has tenant users what are like a directory. 
+ return (record_type == "pamDirectory" or + record_type == "pamAzureConfiguration") + + @staticmethod + def _get_shared_folder(vault: vault_online.VaultOnline, pad: str, gateway_context: GatewayContext) -> str: + while True: + shared_folders = gateway_context.get_shared_folders(vault) + index = 0 + for folder in shared_folders: + logger.info(f"{pad}* {str(index+1)} - {folder.get('uid')} {folder.get('name')}") + index += 1 + selected = input(f"{pad}Enter number of the shared folder>") + try: + return shared_folders[int(selected) - 1].get("uid") + except ValueError: + logger.error(f"{pad}Input was not a number.") + + @staticmethod + def get_field_values(record: vault_record.TypedRecord, field_type: str) -> List[Any]: + return next( + (f.value + for f in record.fields + if f.type == field_type), + None + ) + + def get_keys_by_record(self, context: KeeperParams, gateway_context: GatewayContext, + record: vault_record.TypedRecord) -> List[str]: + """ + For the record, get the values of fields that are key for this record type. + + :param params: + :param gateway_context: + :param record: + :return: + """ + + key_field = Process.get_key_field(record.record_type) + keys = [] + if key_field == "host_port": + values = self.get_field_values(record, "pamHostname") # type: List[dict] + if len(values) == 0: + return [] + + host = values[0].get("hostName") + port = values[0].get("port") + if port is not None: + if host is not None: + keys.append(f"{host}:{port}".lower()) + + elif key_field == "host": + values = self.get_field_values(record, "pamHostname") + if len(values) == 0: + return [] + + host = values[0].get("hostName") + if host is not None: + keys.append(host.lower()) + + elif key_field == "user": + + # This is user protobuf values. + # We could make this also use record linking if we stop using protobuf. 
+ + record_rotation = params.record_rotation_cache.get(record.record_uid) + if record_rotation is not None: + controller_uid = record_rotation.get("configuration_uid") + if controller_uid is None or controller_uid != gateway_context.configuration_uid: + return [] + + resource_uid = record_rotation.get("resource_uid") + # If the resource uid is None, the Admin Cred Record has not been set. + if resource_uid is None: + return [] + + values = self.get_field_values(record, "login") + if len(values) == 0: + return [] + + keys.append(f"{resource_uid}:{values[0]}".lower()) + + return keys + + @staticmethod + def _record_lookup(record_uid: str, context:KeeperParams) -> Optional[NormalizedRecord]: + + """ + Get the record from the Vault, normalize it, and return it. + + Since common code is using this method we want to flatten/abstract the KeeperRecord/TypedRecord. + """ + + vault = context.vault + record = vault.vault_data.load_record(record_uid) + if record is None: + return None + + normalized_record = NormalizedRecord( + record_uid=record.record_uid, + record_type=record.record_type, + title=record.title, + ) + for field in record.fields: + normalized_record.fields.append( + record_types.RecordField( + type=field.type, + label=field.label, + value=field.value, + ) + ) + if record.custom is not None: + for field in record.custom: + normalized_record.fields.append( + record_types.RecordField( + type=field.type, + label=field.label, + value=field.value, + ) + ) + return normalized_record + + def _build_record_cache(self, context: KeeperParams, gateway_context: GatewayContext) -> dict: + + """ + Make a lookup cache for all the records. + + This is used to flag discovered items as existing if the record has already been added. This is used to + prevent duplicate records being added. 
+ """ + + logger.debug(f"building the PAM record cache") + + # Make a cache of existing record by the criteria per record type + cache = { + "pamUser": {}, + "pamMachine": {}, + "pamDirectory": {}, + "pamDatabase": {} + } + + vault = context.vault + + # Set all the PAM Records + records = vault.vault_data.find_records(criteria="pam*", record_type=None, record_version=None) + for record in records: + # If the record type is not part of the cache, skip the record + if record.record_type not in cache: + continue + + # Load the full record + record = vault.vault_data.load_record(record.record_uid) + + cache_keys = self.get_keys_by_record( + context=context, + gateway_context=gateway_context, + record=record + ) + if len(cache_keys) == 0: + continue + + for cache_key in cache_keys: + cache[record.record_type][cache_key] = record.record_uid + + return cache + + def _edit_record(self, content: DiscoveryObject, pad: str, editable: List[str]) -> bool: + + edit_label = input(f"{pad}Enter 'title' or the name of the Label to edit, RETURN to cancel> ") + + # Just pressing return exits the edit + if edit_label == "": + return False + + # If the "title" is entered, then edit the title of the record. + if edit_label.lower() == "title": + new_title = input(f"{pad}Enter new title> ") + content.title = new_title + + # If a field label is entered, and it's in the list of editable fields, then allow the user to edit. 
+ elif edit_label in editable: + new_value = None + if edit_label in self.FIELD_MAPPING: + type_hint = self.FIELD_MAPPING[edit_label].get("type") + if type_hint == "dict": + field_input_format = self.FIELD_MAPPING[edit_label].get("field_input") + new_value = {} + for field in field_input_format: + new_value[field.get('key')] = input(f"{pad}Enter {field_input_format.get('prompt')} value> ") + elif type_hint == "csv": + new_value = input(f"{pad}Enter {edit_label} values, separate with a comma > ") + new_values = map(str.strip, new_value.split(',')) + new_value = "\n".join(new_values) + elif type_hint == "multiline": + logger.info(f"{pad}Enter multilines of text or a path, on the first line, " + "to a file that contains the value.") + logger.info(f"{pad}To end, type 'END' at the start of a new line. You can paste text.") + new_value = "" + first_line = True + while True: + line = input(f"> ").rstrip() + if line == "END": + break + + # If this is the first line, check if line is a path to a file. + if first_line: + try: + test_file = line.strip() + logger.debug(f"is first line, check for file path for '{test_file}'") + if os.path.exists(test_file): + with open(test_file, "r") as fh: + new_value = fh.read() + fh.close() + break + else: + logger.debug(f"first line is not a file path") + except Exception as err: + logger.debug(f"exception checking if file: {err}") + first_line = False + new_value += line + "\n" + elif type_hint == "choice": + + values = self.FIELD_MAPPING[edit_label].get("values") + text_values = [x for x in values] + new_value = input(f"{pad}Enter one of the follow values: {', '.join(text_values)}> ") + new_value = new_value.strip().lower() + if new_value not in values: + logger.error(f"{pad}The value {new_value} is not one of the values allowed.") + return False + else: + new_value = input(f"{pad}Enter new value, or path to a file that contains the value > ") + + # Is the value a path to a file, i.e., a private key file. 
+ try: + if os.path.exists(new_value): + with open(new_value, "r") as fh: + new_value = fh.read() + fh.close() + except (Exception,): + pass + + for edit_field in content.fields: + if edit_field.label == edit_label: + edit_field.value = [new_value] + + # Else, the label they entered cannot be edited. + else: + logger.error(f"{pad}The field is not editable.") + return False + + return True + + @staticmethod + def _add_all_preprocess(vertex: DAGVertex, content: DiscoveryObject, parent_vertex: DAGVertex, + acl: Optional[UserAcl] = None) -> Optional[PromptResult]: + """ + This is client side check if we should skip prompting the user. + + The checks are + * A directory with the same domain already has a record. + + """ + + _ = vertex + _ = acl + + # Check if the directory for a domain exists. + # From the parent, find any directory objects. + # If they already have a record UID, don't prompt about this one. + # Once a directory for the domain exists, the user should not be prompted about this domain anymore. + if content.record_type == "pamDirectory": + for v in parent_vertex.has_vertices(): + other_content = DiscoveryObject.get_discovery_object(v) + if other_content.record_uid is not None and other_content.name == content.name: + return PromptResult(action=PromptActionEnum.SKIP) + return None + + def _prompt_display_fields(self, content: DiscoveryObject, pad: str) -> List[str]: + + editable = [] + for field in content.fields: + has_editable = False + if field.label in PAMGatewayActionDiscoverResultProcessCommand.EDITABLE: + editable.append(field.label) + has_editable = True + value = field.value + + # If there is a value, and it's not just [], also make sure the + if len(value) > 0 and value[0] is not None: + # PAM records will have only 1 item in the value array. 
+ value = value[0] + if field.label in self.FIELD_MAPPING: + type_hint = self.FIELD_MAPPING[field.label].get("type") + formatted_value = [] + if type_hint == "dict": + field_input_format = self.FIELD_MAPPING[field.label].get("field_format") + for format_field in field_input_format: + formatted_value.append(f"{format_field.get('label')}: " + f"{value.get(format_field.get('key'))}") + elif type_hint == "csv": + formatted_value.append(", ".join(value.split("\n"))) + elif type_hint == "multiline": + formatted_value.append(value) + elif type_hint == "choice": + formatted_value.append(value) + value = ", ".join(formatted_value) + else: + if has_editable: + value = "MISSING" + else: + value = "None" + + rows = str(value).split("\n") + if len(rows) > 1: + value = rows[0] + f"... {len(rows)} rows." + + logger.info(f"{pad} " + f"Label: {field.label}, " + f"Type: {field.type}, " + f"Value: {value}") + + if len(content.notes) > 0: + logger.info("") + for note in content.notes: + logger.info(f"{pad}* {note}") + + return editable + + @staticmethod + def _prompt_display_relationships(vertex: DAGVertex, content: DiscoveryObject, pad: str): + + if vertex is None: + return + + if content.record_type == "pamUser": + belongs_to = [] + for v in vertex.belongs_to_vertices(): + resource_content = DiscoveryObject.get_discovery_object(v) + belongs_to.append(resource_content.name) + count = len(belongs_to) + logger.info("") + logger.info(f"{pad}This user is found on {count} resource{'s' if count > 1 else ''}") + + def _prompt(self, + content: DiscoveryObject, + acl: UserAcl, + vertex: Optional[DAGVertex] = None, + parent_vertex: Optional[DAGVertex] = None, + resource_has_admin: bool = True, + item_count: int = 0, + items_left: int = 0, + indent: int = 0, + block_auto_add: bool = False, + dry_run: bool = False, + add_all: bool = False, + vault: Optional[vault_online.VaultOnline] = None, + gateway_context: Optional[GatewayContext] = None + ) -> PromptResult: + + if gateway_context is None: + 
raise Exception("Context not set for processing the discovery results") + + parent_content = DiscoveryObject.get_discovery_object(parent_vertex) + + logger.info("") + + if block_auto_add: + add_all = False + + # If auto add is True, there are sometime we don't want to add the object. + # If we get a result, we want to return it. + # Skip the prompt. + if add_all is True and vertex is not None: + result = self._add_all_preprocess(vertex, content, parent_vertex, acl) + if result is not None: + return result + + # If the record type is a pamUser, then include parent description. + if content.record_type == "pamUser" and parent_vertex is not None: + parent_pad = "" + if indent - 1 > 0: + parent_pad = "".ljust(2 * indent, ' ') + + logger.info(f"{parent_pad}{parent_content.description}") + + pad = "" + if indent > 0: + pad = "".ljust(2 * indent, ' ') + + logger.info(f"{pad}{content.description}") + + show_current_object = True + while show_current_object: + logger.info(f"{pad}Record Title: {content.title}") + + logger.debug(f"Fields: {content.fields}") + + # Display the fields and return a list of fields are editable. 
+ editable = self._prompt_display_fields(content=content, pad=pad) + if vertex is not None: + self._prompt_display_relationships(vertex=vertex, content=content, pad=pad) + + while True: + + shared_folder_uid = content.shared_folder_uid + if shared_folder_uid is None: + shared_folder_uid = gateway_context.default_shared_folder_uid + + count_prompt = "" + if item_count > 0: + count_prompt = f"[{item_count - items_left + 1}/{item_count}]" + edit_add_prompt = f"{count_prompt} " + if len(editable) > 0: + edit_add_prompt += f"(E)dit, " + + shared_folders = gateway_context.get_shared_folders(vault) + if dry_run is False: + if len(shared_folders) > 1: + folder_name = next((x['name'] + for x in shared_folders + if x['uid'] == shared_folder_uid), + None) + edit_add_prompt += f"(A)dd to {folder_name}, "\ + f"Add to (F)older, " + else: + if dry_run is False: + edit_add_prompt += f"(A)dd, " + prompt = f"{edit_add_prompt}(S)kip, (I)gnore, (Q)uit" + + command = "a" + if add_all is False: + command = input(f"{pad}{prompt}> ").lower() + if (command == "a" or command == "f") and dry_run is False: + + logger.info(f"{pad}Adding record to save queue.") + logger.info("") + + if command == "f": + shared_folder_uid = self._get_shared_folder(vault, pad, gateway_context) + + content.shared_folder_uid = shared_folder_uid + + # This happens when the record is a pamUser and parent resource record does not have an + # administrator. + # It's like the reverse of creating an admin after adding the resource. + # It would make this user the admin for the parent resource. + # This condition would be really rare, since to get the users, the resource would have to have an + # admin user. 
+ if content.record_type == "pamUser" and resource_has_admin is False: + + logger.info(f"{parent_content.description} does not have an administrator.") + if (hasattr(parent_content.item, "admin_reason") and + parent_content.item.admin_reason is not None): + logger.info("") + logger.info(parent_content.item.admin_reason) + logger.info("") + + while True: + + yn = input("Do you want to make this user the administrator? [Y/N]> ").lower() + if yn == "": + continue + if yn[0] == "n": + break + if yn[0] == "y": + acl.is_admin = True + break + + return PromptResult( + action=PromptActionEnum.ADD, + acl=acl, + content=content + ) + + elif command == "e" and dry_run is False: + self._edit_record(content, pad, editable) + break + + elif command == "i": + + logger.info(f"{pad}Creating an ignore rule for record.") + return PromptResult( + action=PromptActionEnum.IGNORE, + acl=acl, + content=content + ) + + elif command == "s": + logger.info(f"{pad}Skipping record.") + + return PromptResult( + action=PromptActionEnum.SKIP, + acl=acl, + content=content + ) + elif command == "q": + raise QuitException() + logger.info("") + + return PromptResult( + action=PromptActionEnum.SKIP, + acl=acl, + content=content + ) + + def _find_user_record(self, + bulk_convert_records: List[BulkRecordConvert], + context: Optional[KeeperParams] = None, + gateway_context: Optional[GatewayContext] = None, + record_link: Optional[RecordLink] = None) -> Tuple[Optional[vault_record.TypedRecord], bool]: + + vault = context.vault + + # Get the latest records + vault.vault_data.sync_data = True + + # Make a list of all records in the shared folders. + # We will use this to check if a selected user is in the shared folders. 
+ shared_record_uids = [] + for shared_folder in gateway_context.get_shared_folders(vault): + folder = shared_folder.get("folder") + if "records" in folder: + for record in folder["records"]: + shared_record_uids.append(record.get("record_uid")) + + # Make a list of record we are already converting so we don't show them again. + converting_list = [x.record_uid for x in bulk_convert_records] + + logger.debug(f"shared folders record uid {shared_record_uids}") + + while True: + user_search = input("Enter an user to search for [ENTER/RETURN to quit]> ") + if user_search == "": + logger.error(f"No search terms, not performing search.") + return None, False + + # Search for record with the search string. + # Currently, this only works with TypedRecord, version 3. + user_record = vault.vault_data.find_records( + criteria=user_search, + record_version=3, + record_type=None + ) + # If not record are returned by the search just return None, + if len(user_record) == 0: + logger.error(f"Could not find any records that contain the search text.") + return None, False + + # Find usable admin records. + admin_search_results = [] + for record in user_record: + + user_record = vault.vault_data.get_record(record.record_uid) + if user_record.record_type == "pamUser": + logger.debug(f"{record.record_uid} is a pamUser") + + # If we are already converting this pamUser record, then don't show it. + if record.record_uid in converting_list: + logger.debug(f"pamUser {user_record.title}, {user_record.record_uid} is being converted; " + "BAD for search") + admin_search_results.append( + AdminSearchResult( + record=user_record, + is_directory_user=False, + is_pam_user=True, + being_used=True + ) + ) + continue + + # Does the record exist in the gateway shared folder? + # We want to filter our other gateway's pamUser, or it will get overwhelming. 
+ if user_record.record_uid not in shared_record_uids: + logger.debug(f"pamUser {record.title}, {user_record.record_uid} not in shared " + "folder, BAD for search") + continue + + # If the record does not exist in the record linking, it's orphaned; accept it + # If it does exist, then check if it belonged to a directory. + # Very unlikely a user that belongs to a database or another machine can be used. + + record_vertex = record_link.get_record_link(user_record.record_uid) + is_directory_user = False + if record_vertex is not None: + parent_record_uid = record_link.get_parent_record_uid(user_record.record_uid) + parent_record = vault.vault_data.get_record(parent_record_uid) + if parent_record is not None: + is_directory_user = self._is_directory_user(parent_record.record_type) + if not is_directory_user: + logger.debug(f"pamUser parent for {user_record.title}, " + "{user_record.record_uid} is not a directory; BAD for search") + continue + + logger.debug(f"pamUser {user_record.title}, {user_record.record_uid} is a directory user; " + "good for search") + + else: + logger.debug(f"pamUser {user_record.title}, {user_record.record_uid} does not a parent; " + "good for search") + else: + logger.debug(f"pamUser {user_record.title}, {user_record.record_uid} does not have record " + "linking vertex; good for search") + + admin_search_results.append( + AdminSearchResult( + record=user_record, + is_directory_user=is_directory_user, + is_pam_user=True, + being_used=False + ) + ) + + # Else this is a non-PAM record. 
+ # Make sure it has a login, password, private key + else: + logger.debug(f"{record.record_uid} is NOT a pamUser") + login_field = next((x for x in record.fields if x.type == "login"), None) + password_field = next((x for x in record.fields if x.type == "password"), None) + private_key_field = next((x for x in record.fields if x.type == "keyPair"), None) + + if login_field is not None and (password_field is not None or private_key_field is not None): + admin_search_results.append( + AdminSearchResult( + record=record, + is_directory_user=False, + is_pam_user=False + ) + ) + logger.debug(f"{record.title} is has credentials, good for search") + else: + logger.debug(f"{record.title} is missing full credentials, BAD for search") + + # If all the users have been filtered out, then just return None + if len(admin_search_results) == 0: + logger.error(f"Could not find any available records.") + return None, False + + user_index = 1 + admin_search_results = sorted(admin_search_results, + key=lambda x: x.is_pam_user, + reverse=True) + + has_local_user = False + for admin_search_result in admin_search_results: + is_local_user = False + if admin_search_result.record.record_type != "pamUser": + has_local_user = True + is_local_user = True + + index_str = user_index + if admin_search_result.being_used: + index_str = "-" * len(str(index_str)) + + logger.info(f"[{index_str}] " + f"{'* ' if is_local_user is True else ''}" + f"{admin_search_result.record.title} " + f'{"(Directory User) " if admin_search_result.is_directory_user is True else ""}' + f'{"(Already taken)" if admin_search_result.being_used is True else ""}') + user_index += 1 + + if has_local_user: + logger.info(f"* Not a PAM User record. " + f"A PAM User would be generated from this record.") + + select = input("Enter line number of user record to use, enter/return to refine the search, " + f"or (Q) to quit search. 
> ").lower() + if select == "": + continue + elif select[0] == "q": + return None, False + else: + try: + selected = admin_search_results[int(select) - 1] + if selected.being_used: + logger.error(f"Cannot select a record that has already been taken. " + f"Another record is using this local user as its administrator.") + return None, False + admin_record = selected.record + return admin_record, selected.is_directory_user + except IndexError: + logger.error(f"Entered row index does not exists.") + continue + + return None, False + + @staticmethod + def _handle_admin_record_from_record(record: vault_record.TypedRecord, + content: DiscoveryObject, + context: Optional[KeeperParams] = None, + gateway_context: Optional[GatewayContext] = None) -> Optional[PromptResult]: + + vault = context.vault + + # Is this a pamUser record? + # Return the record UID and set its ACL to be the admin. + if record.record_type == "pamUser": + return PromptResult( + action=PromptActionEnum.ADD, + acl=UserAcl(is_admin=True), + record_uid=record.record_uid, + ) + + # If we are here, this was not a pamUser + # We need to duplicate the record. + # But confirm first + + # Get fields from the old record. + # Copy them into the fields. + login_field = next((x for x in record.fields if x.type == "login"), None) + password_field = next((x for x in record.fields if x.type == "password"), None) + private_key_field = next((x for x in record.fields if x.type == "keyPair"), None) + + content.set_field_value("login", login_field.value) + if password_field is not None: + content.set_field_value("password", password_field.value) + if private_key_field is not None: + value = private_key_field.value + if value is not None and len(value) > 0: + value = value[0] + private_key = value.get("privateKey") + if private_key is not None: + content.set_field_value("private_key", private_key) + + # Check if we have more than one shared folder. + # If we have one, confirm about adding the user. 
+ # If multiple shared folders, allow user to select which one. + shared_folders = gateway_context.get_shared_folders(vault) + if len(shared_folders) == 0: + while True: + yn = input(f"Create a PAM User record from {record.title}? [Y/N]> ").lower() + if yn == "": + continue + elif yn[0] == "n": + return None + elif yn[0] == "y": + content.shared_folder_uid = gateway_context.default_shared_folder_uid + else: + folder_name = next((x['name'] + for x in shared_folders + if x['uid'] == gateway_context.default_shared_folder_uid), + None) + while True: + shared_folders = gateway_context.get_shared_folders(vault) + if len(shared_folders) > 1: + afq = input(f"(A)dd user to {folder_name}, " + f"Add user to (F)older, " + f"(Q)uit > ").lower() + else: + afq = input(f"(A)dd user, " + f"(Q)uit > ").lower() + + if afq == "": + continue + if afq[0] == "a": + content.shared_folder_uid = gateway_context.default_shared_folder_uid + break + elif afq[0] == "f": + shared_folder_uid = PAMGatewayActionDiscoverResultProcessCommand._get_shared_folder( + vault, "", gateway_context) + if shared_folder_uid is not None: + content.shared_folder_uid = shared_folder_uid + break + + return PromptResult( + action=PromptActionEnum.ADD, + acl=UserAcl(is_admin=True), + content=content, + note=f"This record replaces record {record.title} ({record.record_uid}). " + "The password on that record will not be rotated." 
+ ) + + def _prompt_admin(self, + parent_vertex: DAGVertex, + content: DiscoveryObject, + acl: UserAcl, + bulk_convert_records: List[BulkRecordConvert], + indent: int = 0, + context: Optional[KeeperParams] = None, + gateway_context: Optional[GatewayContext] = None) -> Optional[PromptResult]: + + if content is None: + raise Exception("The admin content was not passed in to prompt the user.") + + parent_content = DiscoveryObject.get_discovery_object(parent_vertex) + + logger.info("") + vault = context.vault + while True: + + logger.info(f"{parent_content.description} does not have an administrator user.") + if hasattr(parent_content.item, "admin_reason") is True and parent_content.item.admin_reason is not None: + logger.info("") + logger.info(parent_content.item.admin_reason) + logger.info("") + + action = input("Would you like to (A)dd new administrator user, (F)ind an existing admin, or (S)kip add? > ").lower() + + if action == "": + continue + + if action[0] == 'a': + prompt_result = self._prompt( + vault=vault, + gateway_context=gateway_context, + vertex=None, + parent_vertex=parent_vertex, + content=content, + acl=acl, + indent=indent + 2, + block_auto_add=True + ) + login = content.get_field_value("login") + if login is None or login == "": + logger.error("A value is needed for the login field.") + continue + + logger.info(f"Adding admin record to save queue.") + return prompt_result + elif action[0] == 'f': + logger.info("") + record, is_directory_user = self._find_user_record(context=context, + gateway_context=gateway_context, + bulk_convert_records=bulk_convert_records) + if record is not None: + admin_prompt_result = self._handle_admin_record_from_record( + record=record, + content=content, + context=context, + gateway_context=gateway_context + ) + if admin_prompt_result is not None: + if admin_prompt_result.action == PromptActionEnum.ADD: + admin_prompt_result.is_directory_user = is_directory_user + logger.info(f"Adding admin record to save queue.") + 
return admin_prompt_result + elif action[0] == 's': + return PromptResult( + action=PromptActionEnum.SKIP + ) + logger.info("") + + @staticmethod + def _display_auto_add_results(bulk_add_records: List[BulkRecordAdd]): + + """ + Display the number of record created from rule engine ADD results and smart add function. + """ + + add_count = len(bulk_add_records) + if add_count > 0: + logger.info("") + logger.info(f"From the rules, automatically queued {add_count} " + f"record{'' if add_count == 1 else 's'} to be added.") + + @staticmethod + def _prompt_confirm_add(bulk_add_records: List[BulkRecordAdd]): + + """ + If we quit, we want to ask the user if they want to add record for discovery objects that they selected + for addition. + """ + + logger.info("") + count = len(bulk_add_records) + if count == 1: + msg = (f"There is 1 record queued to be added to your vault. " + f"Do you wish to add it? [Y/N]> ") + else: + msg = (f"There are {count} records queued to be added to your vault. " + f"Do you wish to add them? [Y/N]> ") + while True: + yn = input(msg).lower() + if yn == "": + continue + if yn[0] == "y": + return True + elif yn[0] == "n": + return False + logger.error("Did not get 'Y' or 'N'") + + @staticmethod + def _prepare_record(content: DiscoveryObject, context: Optional[KeeperParams] = None) -> Tuple[Any, str]: + + """ + Prepare the Vault record side. + + It's not created here. + It will be created at the end of the processing run in bulk. + We to build a record to get a record UID. + + :params content: The discovery object instance. + :params context: Optionally, it will contain information set from the run() method. + :returns: Returns an unsaved Keeper record instance. 
+ """ + + # DEFINE V3 RECORD + + # Create an instance of a vault record to structure the data + record = vault_record.TypedRecord() + record.type_name = content.record_type + record.record_uid = utils.generate_uid() + record.record_key = utils.generate_aes_key() + record.title = content.title + for field in content.fields: + field_args = { + "field_type": field.type, + "field_value": field.value + } + if field.type != field.label: + field_args["field_label"] = field.label + record_field = vault_record.TypedField.new_field(**field_args) + record_field.required = field.required + record.fields.append(record_field) + + vault = context.vault + folder = vault.vault_data.get_folder(content.shared_folder_uid) + folder_key = None + if folder.folder_type == 'shared_folder_folder': + shared_folder_uid = folder.folder_scope_uid + elif folder.folder_type == 'shared_folder': + shared_folder_uid = folder.folder_uid + else: + shared_folder_uid = None + if shared_folder_uid and shared_folder_uid in vault.vault_data._shared_folders: + shared_folder = vault.vault_data.get_folder(shared_folder_uid) + folder_key = shared_folder.folder_key + + # DEFINE PROTOBUF FOR RECORD + + record_add_protobuf = record_pb2.RecordAdd() + record_add_protobuf.record_uid = utils.base64_url_decode(record.record_uid) + record_add_protobuf.record_key = crypto.encrypt_aes_v2(record.record_key, context.vault.keeper_auth.auth_context.data_key) + record_add_protobuf.client_modified_time = utils.current_milli_time() + record_add_protobuf.folder_type = record_pb2.user_folder + if folder: + record_add_protobuf.folder_uid = utils.base64_url_decode(folder.folder_uid) + if folder.folder_type == 'shared_folder': + record_add_protobuf.folder_type = record_pb2.shared_folder + elif folder.folder_type == 'shared_folder_folder': + record_add_protobuf.folder_type = record_pb2.shared_folder_folder + if folder_key: + record_add_protobuf.folder_key = crypto.encrypt_aes_v2(record.record_key, folder_key) + + data = 
vault_extensions.extract_typed_record_data(record, vault.vault_data.get_record_type_by_name(record.record_type)) + json_data = vault_extensions.get_padded_json_bytes(data) + record_add_protobuf.data = crypto.encrypt_aes_v2(json_data, record.record_key) + + if context.vault.keeper_auth.auth_context.enterprise_ec_public_key: + audit_data = vault_extensions.extract_audit_data(record) + if audit_data: + record_add_protobuf.audit.version = 0 + record_add_protobuf.audit.data = crypto.encrypt_ec( + json.dumps(audit_data).encode('utf-8'), context.vault.keeper_auth.auth_context.enterprise_ec_public_key) + + return record_add_protobuf, record.record_uid + + @classmethod + def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: KeeperParams, gateway_context: GatewayContext) -> BulkProcessResults: + + """ + Create Vault records, setup rotation settings, and configure the resource (if resource). + """ + + if len(bulk_add_records) == 1: + logger.info("Adding the record to the Vault ...") + else: + logger.info(f"Adding {len(bulk_add_records)} records to the Vault ...") + + build_process_results = BulkProcessResults() + + ############################################################################################################## + # + # STEP 1 - Batch add new records + + # Generate a list of RecordAdd instance. + # In BulkRecordAdd they will be the record instance. 
+ record_add_list = [r.record for r in bulk_add_records] # type: List[record_pb2.RecordAdd] + + records_per_request = 999 + + add_results = [] + logger.debug("adding record in batches") + logger.info("batch record create: ", end="") + sys.stdout.flush() + while record_add_list: + logger.info(".", end="") + sys.stdout.flush() + logger.debug(f"* adding batch") + rq = record_pb2.RecordsAddRequest() + rq.records.extend(record_add_list[:records_per_request]) + record_add_list = record_add_list[records_per_request:] + rs = context.vault.keeper_auth.execute_auth_rest('vault/records_add', rq, response_type=record_pb2.RecordsModifyResponse) + add_results.extend(rs.records) + logger.info("") + sys.stdout.flush() + + logger.debug(f"add_result: {add_results}") + + if len(add_results) != len(bulk_add_records): + logger.debug(f"attempted to batch add {len(bulk_add_records)} record(s), " + f"only have {len(add_results)} results.") + + ############################################################################################################## + # + # STEP 2 - Add rotation settings for user and resource configuration for resources + # At this point the all the records have been created. + + # Keep track of each record we create a rotation for to avoid version problems, if there was a dup. + created_cache = [] + + # TODO: There is a bulk version of the following code, it's not live. + # Wait until live, then switch code to use that. + + # For the records passed in to be created. + logger.info("add rotation settings: ", end="") + sys.stdout.flush() + for bulk_record in bulk_add_records: + if bulk_record.record_uid in created_cache: + logger.debug(f"found a duplicate of record uid: {bulk_record.record_uid}") + continue + logger.info(".", end="") + sys.stdout.flush() + + # Grab the type Keeper record instance, and title from that record. + pb_add_record = bulk_record.record + title = bulk_record.title + + rotation_disabled = False + + # Find the result for this record. 
+ result = None + for x in add_results: + logger.debug(f"{pb_add_record.record_uid} vs {x.record_uid}") + if pb_add_record.record_uid == x.record_uid: + result = x + break + + # If we didn't get a result, then don't add the rotation settings. + if result is None: + build_process_results.failure.append( + BulkRecordFail( + title=title, + error="No status on addition to Vault. Cannot determine if added or not." + ) + ) + logger.debug(f"Did not get a result when adding record {title}") + continue + + # Check if addition failed. If it did fail, don't add the rotation settings. + success = (result.status == record_pb2.RecordModifyResult.DESCRIPTOR.values_by_name['RS_SUCCESS'].number) + status = record_pb2.RecordModifyResult.DESCRIPTOR.values_by_number[result.status].name + + if not success: + build_process_results.failure.append( + BulkRecordFail( + title=title, + error=status + ) + ) + logger.debug(f"Had problem adding record for {title}: {status}") + continue + + # Only set the rotation setting if the record is a PAM User. + if bulk_record.record_type == PAM_USER: + + rq = router_pb2.RouterRecordRotationRequest() + rq.recordUid = utils.base64_url_decode(bulk_record.record_uid) + rq.revision = 0 + + # Set the gateway/configuration that this record should be connected. + rq.configurationUid = utils.base64_url_decode(gateway_context.configuration_uid) + + if bulk_record.parent_record_uid is not None: + rq.resourceUid = utils.base64_url_decode(bulk_record.parent_record_uid) + + # Right now, the schedule and password complexity are not set. This would be part of a rule engine. + rq.schedule = '' + rq.pwdComplexity = b'' + rq.disabled = rotation_disabled + + router_utils.router_set_record_rotation_information(context, rq) + + # This will be a resource. + # A LINK edge will be created between the configuration and resource. + # If there is an admin user, it will be set on the resource. + else: + + # This will create a LINK between the PAM Configuration and the resource. 
+ rq = pam_pb2.PAMResourceConfig() + rq.recordUid = utils.base64_url_decode(bulk_record.record_uid) + rq.networkUid = utils.base64_url_decode(gateway_context.configuration_uid) + if bulk_record.admin_uid: + rq.adminUid = utils.base64_url_decode(bulk_record.admin_uid) + + router_utils.router_configure_resource(context, rq) + + created_cache.append(bulk_record.record_uid) + + build_process_results.success.append( + BulkRecordSuccess( + title=title, + record_uid=bulk_record.record_uid + ) + ) + logger.info("") + sys.stdout.flush() + + context.sync_data = True + + return build_process_results + + @classmethod + def _convert_records(cls, bulk_convert_records: List[BulkRecordConvert], context: KeeperParams, gateway_context: Optional[GatewayContext] = None): + + vault = context.vault + for bulk_convert_record in bulk_convert_records: + + record = vault.vault_data.load_record(bulk_convert_record.record_uid) + + rotation_disabled = False + + rq = router_pb2.RouterRecordRotationRequest() + rq.recordUid = utils.base64_url_decode(bulk_convert_record.record_uid) + + # We can't set the version to 0 if it's greater than 0, look up prior version. + record_rotation_revision = params.record_rotation_cache.get(bulk_convert_record.record_uid) + rq.revision = record_rotation_revision.get('revision') if record_rotation_revision else 0 + + # Set the gateway/configuration that this record should be connected. + rq.configurationUid = utils.base64_url_decode(gateway_context.configuration_uid) + + # Only set the resource if the record type is a PAM User. + # Machines, databases, and directories have a login/password in the record that indicates who the admin is. + if record.record_type == "pamUser" and bulk_convert_record.parent_record_uid is not None: + rq.resourceUid = utils.base64_url_decode(bulk_convert_record.parent_record_uid) + + # Right now, the schedule and password complexity are not set. This would be part of a rule engine. 
+ rq.schedule = '' + rq.pwdComplexity = b'' + rq.disabled = rotation_disabled + + router_utils.router_set_record_rotation_information(context, rq) + + vault.sync_down(force=True) + + @staticmethod + def _get_directory_info(domain: str, + skip_users: bool = False, + context: Optional[KeeperParams] = None, + gateway_context: Optional[GatewayContext] = None) -> Optional[DirectoryInfo]: + """ + Get information about this record from the vault records. + + """ + + directory_info = DirectoryInfo() + + vault = context.vault + + # Find the all directory records, in for this gateway, that have a domain that matches what we are looking for. + for directory_record in vault.vault_data.find_records(criteria=None, record_type="pamDirectory", record_version=None): + directory_record = vault.vault_data.load_record(directory_record.record_uid) + + info = params.record_rotation_cache.get(directory_record.record_uid) + if info is None: + continue + + # Make sure this user is part of this gateway. + if info.get("configuration_uid") != gateway_context.configuration_uid: + continue + + domain_field = directory_record.get_typed_field("text", label="domainName") + if len(domain_field.value) == 0 or domain_field.value[0] == "": + continue + + if domain_field.value[0].lower() != domain.lower(): + continue + + directory_info.directory_record_uids.append(directory_record.record_uid) + + if directory_info.has_directories is True and skip_users is False: + + for user_record in vault.vault_data.find_records(criteria=None, record_type="pamUser", record_version=None): + info = params.record_rotation_cache.get(user_record.record_uid) + if info is None: + continue + + if info.get("resource_uid") is None or info.get("resource_uid") == "": + continue + + # If the user's belongs to a directory, and add it to the directory user list. 
+ if info.get("resource_uid") in info.directory_record_uids: + directory_info.directory_user_record_uids.append(user_record.record_uid) + + return directory_info + + @staticmethod + def remove_job(context: KeeperParams, configuration_record: vault_record.KeeperRecord, job_id: str): + + try: + jobs = Jobs(record=configuration_record, context=context) + jobs.cancel(job_id) + logger.info(f"No items left to process. Removing completed discovery job.") + except Exception as err: + logger.error(err) + logger.error(f"No items left to process. Failed to remove discovery job.") + + def preview(self, job_item: JobItem, context: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0): + + sync_point = job_item.sync_point + infra = Infrastructure(record=gateway_context.configuration, + context=context, + logger=logger, + debug_level=debug_level) + infra.load(sync_point) + + configuration = None + try: + configuration = infra.get_root.has_vertices()[0] + except (Exception,): + logger.error(f"Could not find the configuration in the infrastructure graph. " + f"Has discovery been run for this gateway?") + + record_type_to_vertices_map = sort_infra_vertices(configuration) + + # ------------ + + def _print_resource(rt: str, rule_result: str): + + printed_something = False + + titles = { + "pamDirectory": "Directories", + "pamMachine": "Machines", + "pamDatabase": "Databases" + } + + for rv in record_type_to_vertices_map[rt]: + if not rv.active or not rv.has_data: + continue + user_vertices = rv.has_vertices() + + user_list = [] + for user_vertex in user_vertices: + if not user_vertex.active or not user_vertex.has_data: + continue + + user_content = DiscoveryObject.get_discovery_object(user_vertex) + if user_content.ignore_object or self._record_lookup(user_content.record_uid, context, gateway_context) is not None: + continue + + user_list.append(f" . 
{user_content.item.user} ({user_content.name})") + + c = DiscoveryObject.get_discovery_object(rv) + if len(user_list) == 0 and c.action_rules_result != rule_result or c.ignore_object: + continue + + has_record = "" + record_uid = c.record_uid + if record_uid is not None: + if self._record_lookup(record_uid, context, gateway_context): + has_record = f" (record exists: {record_uid})" + if len(user_list) == 0: + continue + else: + record_uid = None + + if c.action_rules_result != rule_result and not record_uid: + continue + + title = titles.get(c.record_type) + if title is not None: + logger.info(f" {(title)}") + titles[c.record_type] = None + + ip = "" + if c.item.host != c.item.ip: + ip = f" ({c.item.ip})" + + with_admin = "" + if c.admin_uid is not None and not record_uid: + with_admin = f" with Administrator UID {c.admin_uid}" + + logger.info(f" * {c.description}{ip}{with_admin}{has_record}") + printed_something = True + + if record_uid: + for user in user_list: + logger.info(user) + + return printed_something + + # ------------ + + def _print_cloud_user(rt: str, rule_result: str): + + title = "Users" + + for user_vertex in record_type_to_vertices_map[rt]: + if not user_vertex.active or not user_vertex.has_data: + continue + + uc = DiscoveryObject.get_discovery_object(user_vertex) + + if (uc.action_rules_result != rule_result + or uc.ignore_object + or self._record_lookup(uc.record_uid, context, gateway_context) is not None): + continue + + if title is not None: + logger.info(f" {(title)}") + title = None + + logger.info(f" * {uc.item.user} ({uc.name})") + + # ------------ + + logger.info("") + logger.info("Will Be Automatically Added") + nothing_to_print = True + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + if record_type == "pamUser": + _print_cloud_user("pamUser", rule_result="add") + else: + if _print_resource(record_type, rule_result="add"): + nothing_to_print = False + if nothing_to_print: + 
logger.info(f" {'No records will be automatically added.'}") + + logger.info("") + logger.info("Will Be Prompted For") + nothing_to_print = True + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + if record_type == "pamUser": + _print_cloud_user("pamUser", rule_result="prompt") + else: + if _print_resource(record_type, rule_result="prompt"): + nothing_to_print = False + if nothing_to_print: + logger.info(f" {'No items will be prompted.'}") + + logger.info("") + + def execute(self, context: KeeperParams, **kwargs): + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + vault = context.vault + + do_preview = kwargs.get("do_preview", False) + job_id = kwargs.get("job_id") + add_all = kwargs.get("add_all", False) + debug_level = kwargs.get("debug_level", 0) + + all_gateways = GatewayContext.all_gateways(vault) + + configuration_records = GatewayContext.get_configuration_records(vault=vault) + for configuration_record in configuration_records: + + gateway_context = GatewayContext.from_configuration_uid(vault=vault, + configuration_uid=configuration_record.record_uid, + gateways=all_gateways) + if gateway_context is None: + continue + + record_cache = self._build_record_cache( + context=context, + gateway_context=gateway_context + ) + + # Get the current job. + # There can only be one active job. + # This will give us the sync point for the delta + jobs = Jobs(record=configuration_record, context=context, logger=logger, debug_level=debug_level) + job_item = jobs.current_job + if job_item is None: + continue + + # If this is not the job we are looking for, continue to the next gateway. + if job_item.job_id != job_id: + continue + + if job_item.end_ts is None: + logger.error(f'Discovery job is currently running. Cannot process.') + return + if job_item.success is False: + logger.error(f'Discovery job failed. 
Cannot process.') + return + + # Preview is a just a way to list which items will be added or prompted. + if do_preview: + self.preview( + job_item=job_item, + context=context, + gateway_context=gateway_context, + ) + return + + process = Process( + record=configuration_record, + job_id=job_item.job_id, + context=context, + logger=logger, + debug_level=debug_level, + ) + + if add_all: + logger.info(f"The ADD ALL flag has been set. All found items will be added.") + logger.info("") + + try: + results = process.run( + + # This method can get a record using the record UID + record_lookup_func=self._record_lookup, + + # Prompt user the about adding records + prompt_func=self._prompt, + + # Prompt user for an admin for a resource + prompt_admin_func=self._prompt_admin, + + # If quit, confirm if the user wants to add records + prompt_confirm_add_func=self._prompt_confirm_add, + + # Prepare records and place in queue; does not add record to vault + record_prepare_func=self._prepare_record, + + # Add record to the vault, protobuf, and record-linking graph + record_create_func=self._create_records, + + # This function will take existing pamUser record and make them belong to this + # gateway. + record_convert_func=self._convert_records, + + # A function to get directory users + directory_info_func=self._get_directory_info, + + # Pass method that will display auto added records. + auto_add_result_func=self._display_auto_add_results, + + # Provides a cache of the record key to record UID. + record_cache=record_cache, + + # Commander-specific context. 
+ # Record link will be added by Process run as "record_link" + context=context, + gateway_context=gateway_context, + dry_run=False, + add_all=add_all, + ) + + logger.debug(f"Results: {results}") + + logger.info("") + if results is not None and results.num_results > 0: + logger.info(f"Successfully added {results.success_count} " + f"record{'s' if results.success_count != 1 else ''}.") + if results.has_failures: + logger.info(f"There were {results.failure_count} " + f"failure{'s' if results.failure_count != 1 else ''}.") + for fail in results.failure: + logger.info(f" * {fail.title}: {fail.error}") + + if process.no_items_left is True: + self.remove_job(context=context, configuration_record=configuration_record, job_id=job_id) + else: + logger.info(f"No records have been added.") + + except NoDiscoveryDataException: + logger.info(f"All items have been added for this discovery job.") + self.remove_job(context=context, configuration_record=configuration_record, job_id=job_id) + + except Exception as err: + logger.error(f"Could not process discovery: {err}") + raise err + + return + + logger.info(f"Could not find the Discovery job.") + logger.info("") + + +class PAMDiscoveryRuleCommand(base.GroupCommand): + + def __init__(self): + super().__init__('PAM Discovery Rule') + self.register_command(PAMGatewayActionDiscoverRuleAddCommand(), 'add', 'a') + self.register_command(PAMGatewayActionDiscoverRuleListCommand(), 'list', 'l') + self.register_command(PAMGatewayActionDiscoverRuleRemoveCommand(), 'remove', 'r') + self.register_command(PAMGatewayActionDiscoverRuleUpdateCommand(), 'update', 'u') + self.default_verb = 'list' diff --git a/keepercli-package/src/keepercli/commands/pam/discovery/rule_commands.py b/keepercli-package/src/keepercli/commands/pam/discovery/rule_commands.py new file mode 100644 index 00000000..22da42c7 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/discovery/rule_commands.py @@ -0,0 +1,436 @@ +import argparse +from typing import List + 
+from ....params import KeeperParams +from ....helpers import router_utils +from .... import api +from .__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from ..pam_dto import GatewayAction, GatewayActionDiscoverRuleValidate, GatewayActionDiscoverRuleValidateInputs + +from keepersdk.helpers.keeper_dag.dag_types import Statement +from keepersdk.helpers.keeper_dag.rule import Rules, ActionRuleItem, RuleItem, RuleTypeEnum, RuleActionEnum +from keepersdk.proto import pam_pb2 + + +logger = api.get_logger() + + +class PAMGatewayActionDiscoverRuleListCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover rule list') + PAMGatewayActionDiscoverRuleListCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name of UID.') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--search', '-s', required=False, dest='search', action='store', + help='Search for rules.') + + @staticmethod + def print_rule_table(rule_list: List[RuleItem]): + + logger.info("") + logger.info(f"{'Rule ID'.ljust(15, ' ')} " + f"{'Name'.ljust(20, ' ')} " + f"{'Action'.ljust(6, ' ')} " + f"{'Priority'.ljust(8, ' ')} " + f"{'Case'.ljust(12, ' ')} " + f"{'Added'.ljust(19, ' ')} " + f"{'Shared Folder UID'.ljust(22, ' ')} " + f"{'Admin UID'.ljust(22, ' ')} " + "Rule" + ) + + logger.info(f"{''.ljust(15, '=')} " + f"{''.ljust(20, '=')} " + f"{''.ljust(6, '=')} " + f"{''.ljust(8, '=')} " + f"{''.ljust(12, '=')} " + f"{''.ljust(19, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(10, '=')} ") + + for rule in rule_list: + if rule.case_sensitive: + ignore_case_str = 
"Sensitive" + else: + ignore_case_str = "Insensitive" + + shared_folder_uid = "" + if rule.shared_folder_uid is not None: + shared_folder_uid = rule.shared_folder_uid + + admin_uid = "" + if rule.admin_uid is not None: + admin_uid = rule.admin_uid + + name = "" + if rule.name is not None: + name = rule.name + + action_value = f"NONE" + if rule.action is not None: + color = "" + action_value = rule.action.value + + logger.info(f"{rule.rule_id.ljust(14, ' ')} " + f"{name[:20].ljust(20, ' ')} " + f"{color}{action_value.ljust(6, ' ')} " + f"{str(rule.priority).rjust(8, ' ')} " + f"{ignore_case_str.ljust(12, ' ')} " + f"{rule.added_ts_str.ljust(19, ' ')} " + f"{shared_folder_uid.ljust(22, ' ')} " + f"{admin_uid.ljust(22, ' ')} " + f"{Rules.make_action_rule_statement_str(rule.statement)}") + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + configuration_uid = kwargs.get('configuration_uid') + vault = context.vault + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway_uid=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + rules = Rules(record=gateway_context.configuration, context=context) + rule_list = rules.rule_list(rule_type=RuleTypeEnum.ACTION, + search=kwargs.get("search")) # type: List[RuleItem] + if len(rule_list) == 0: + logger.info("") + text = f"There are no rules. " \ + f"Use 'pam action discover rule add -g {gateway_context.gateway_uid} " + if configuration_uid: + text += f"-c {gateway_context.configuration_uid}' " + text += f"to create rules." 
+ logger.info(text) + return + + self.print_rule_table(rule_list=rule_list) + + +class PAMGatewayActionDiscoverRuleAddCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover rule add') + PAMGatewayActionDiscoverRuleAddCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name of UID.') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + + parser.add_argument('--action', '-a', required=True, choices=['add', 'ignore', 'prompt'], + dest='rule_action', action='store', help='Action to take if rule matches') + parser.add_argument('--priority', '-p', required=True, dest='priority', action='store', type=int, + help='Rule execute priority') + parser.add_argument('--name', '-n', required=False, dest='name', action='store', type=str, + help='Rule name') + parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true', + help='Ignore value case. Rule value must be in lowercase.') + parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid', + action='store', help='Folder to place record.') + parser.add_argument('--admin-uid', required=False, dest='admin_uid', + action='store', help='Admin record UID to use for resource.') + parser.add_argument('--statement', '-s', required=True, dest='statement', action='store', + help='Rule statement') + + @staticmethod + def validate_rule_statement(context: KeeperParams, gateway_context: GatewayContext, statement: str) \ + -> List[Statement]: + + # Send rule the gateway to be validated. The rule is encrypted. It might contain sensitive information. 
+        action_inputs = GatewayActionDiscoverRuleValidateInputs(
+            configuration_uid=gateway_context.configuration_uid,
+            statement=gateway_context.encrypt_str(statement)
+        )
+        conversation_id = GatewayAction.generate_conversation_id()
+        router_response = router_utils.router_send_action_to_gateway(
+            context=context,
+            gateway_action=GatewayActionDiscoverRuleValidate(
+                inputs=action_inputs,
+                conversation_id=conversation_id),
+            message_type=pam_pb2.CMT_DISCOVERY,
+            is_streaming=False,
+            destination_gateway_uid_str=gateway_context.gateway_uid
+        )
+
+        data = PAMGatewayActionDiscoverCommandBase.get_response_data(router_response)
+
+        if data is None:
+            raise Exception("The router returned a failure.")
+        elif data.get("success") is False:
+            error = data.get("error")
+            raise Exception(f"The rule does not appear valid: {error}")
+
+        statement_struct = data.get("statementStruct")
+        logger.debug(f"Rule Structure = {statement_struct}")
+        if not isinstance(statement_struct, list):
+            raise Exception(f"The structured rule statement is not a list.")
+        ret = []
+        for item in statement_struct:
+            ret.append(
+                Statement(
+                    field=item.get("field"),
+                    operator=item.get("operator"),
+                    value=item.get("value")
+                )
+            )
+
+        return ret
+
+    def execute(self, context: KeeperParams, **kwargs):
+        try:
+            gateway_uid = kwargs.get("gateway")
+            gateway_context = GatewayContext.from_gateway(context=context,
+                                                          gateway_uid=gateway_uid,
+                                                          configuration_uid=kwargs.get('configuration_uid'))
+            if gateway_context is None:
+                logger.error(f"Could not find the gateway configuration for {gateway_uid}.")
+                return
+
+            # If we are setting the shared_folder_uid, make sure it exists. 
+ shared_folder_uid = kwargs.get("shared_folder_uid") + if shared_folder_uid is not None: + shared_folder_uids = gateway_context._shared_folders if gateway_context._shared_folders is not None else [] + exists = next((x for x in shared_folder_uids if x["uid"] == shared_folder_uid), None) + if exists is None: + logger.error(f"The shared folder UID {shared_folder_uid} is not part of this " + f"application/gateway. Valid shared folder UID are:") + for item in shared_folder_uids: + logger.error(f"* {item['uid']} - {item['name']}") + return + + statement = kwargs.get("statement") + statement_struct = self.validate_rule_statement( + context=context, + gateway_context=gateway_context, + statement=statement + ) + + shared_folder_uid = kwargs.get("shared_folder_uid") + if shared_folder_uid is not None and len(shared_folder_uid) != 22: + logger.error(f"The shared folder UID {shared_folder_uid} is not the correct length.") + return + + admin_uid = kwargs.get("admin_uid") + if admin_uid is not None and len(admin_uid) != 22: + logger.error(f"The admin UID {admin_uid} is not the correct length.") + return + + # If the rule passes its validation, then add control DAG + rules = Rules(record=gateway_context.configuration, context=context) + new_rule = ActionRuleItem( + name=kwargs.get("name"), + action=kwargs.get("rule_action"), + priority=kwargs.get("priority"), + case_sensitive=not kwargs.get("ignore_case", False), + shared_folder_uid=shared_folder_uid, + admin_uid=admin_uid, + statement=statement_struct, + enabled=True + ) + rules.add_rule(new_rule) + + logger.info(f"Rule has been added") + except Exception as err: + logger.error(f"Rule was not added: {err}") + + +class PAMGatewayActionDiscoverRuleUpdateCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action discover rule update') + PAMGatewayActionDiscoverRuleUpdateCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def 
add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name of UID.') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--rule-id', '-i', required=True, dest='rule_id', action='store', + help='Identifier for the rule') + parser.add_argument('--action', '-a', required=False, choices=['add', 'ignore', 'prompt'], + dest='rule_action', action='store', help='Update the action to take if rule matches') + parser.add_argument('--priority', '-p', required=False, dest='priority', action='store', type=int, + help='Update the rule execute priority') + parser.add_argument('--name', '-n', required=False, dest='name', action='store', type=str, + help='Rule name') + parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true', + help='Update the rule to ignore case') + parser.add_argument('--no-ignore-case', required=False, dest='ignore_case', action='store_false', + help='Update the rule to not ignore case') + parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid', + action='store', help='Update the folder to place record.') + parser.add_argument('--admin-uid', required=False, dest='admin_uid', + action='store', help='Admin record UID to use for resource.') + parser.add_argument('--clear-shared-folder-uid', required=False, dest='clear_shared_folder_uid', + action='store_true', help='Clear shared folder UID, use default.') + parser.add_argument('--clear-admin-uid', required=False, dest='clear_admin_uid', + action='store_true', help='Clear admin UID') + parser.add_argument('--statement', '-s', required=False, dest='statement', action='store', + help='Update the rule statement') + parser.add_argument('--active', required=False, dest='active', action='store_true', + help='Enable 
rule.') + parser.add_argument('--disable', required=False, dest='active', action='store_false', + help='Disable rule.') + parser.set_defaults(active=None, ignore_case=None) + + def execute(self, context: KeeperParams, **kwargs): + vault = context.vault + gateway = kwargs.get("gateway") + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway_uid=gateway, + configuration_uid=kwargs.get('configuration_uid')) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + try: + rule_id = kwargs.get("rule_id") + rules = Rules(record=gateway_context.configuration, context=context) + rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id) + if rule_item is None: + raise ValueError("Rule Id does not exist.") + + rule_action = kwargs.get("rule_action") + if rule_action is not None: + action = RuleActionEnum.find_enum(rule_action) + if action is None: + raise ValueError(f"The action does not look correct: {rule_action}") + rule_item.action = action + + priority = kwargs.get("priority") + if priority is not None: + logger.info(" * Changing the priority of the rule.") + rule_item.priority = priority + + ignore_case = kwargs.get("ignore_case") + if ignore_case is not None: + if ignore_case: + logger.info(" * Ignore the case of text.") + else: + logger.info(" * Make rule text case sensitive.") + + rule_item.case_sensitive = not ignore_case + + if kwargs.get("clear_shared_folder_uid"): + logger.info(" * Clearing shared folder.") + rule_item.shared_folder_uid = None + else: + shared_folder_uid = kwargs.get("shared_folder_uid") + if shared_folder_uid is not None: + if len(shared_folder_uid) != 22: + logger.error(f"The shared folder UID {shared_folder_uid} is not the correct length.") + logger.info(" * Changing shared folder UID.") + rule_item.shared_folder_uid = shared_folder_uid + + if kwargs.get("clear_admin_uid"): + logger.info(" * Clearing resource admin UID.") + rule_item.admin_uid = 
None
+            else:
+                admin_uid = kwargs.get("admin_uid")
+                if admin_uid is not None:
+                    if len(admin_uid) != 22:
+                        logger.error(f"The admin UID {admin_uid} is not the correct length.")
+                        return
+                    logger.info(" * Changing the resource admin UID.")
+                    rule_item.admin_uid = admin_uid
+
+            statement = kwargs.get("statement")
+            if statement is not None:
+                # validate_rule_statement will throw exceptions.
+                statement_struct = PAMGatewayActionDiscoverRuleAddCommand.validate_rule_statement(
+                    context=context,
+                    gateway_context=gateway_context,
+                    statement=statement
+                )
+
+                logger.info(" * Changing the rule statement.")
+                rule_item.statement = statement_struct
+
+            name = kwargs.get("name")
+            if name is not None:
+                logger.info(" * Changing the rule name.")
+                rule_item.name = name
+
+            enabled = kwargs.get("active")
+            if enabled is not None:
+                if enabled:
+                    logger.info(" * Enabling the rule.")
+                else:
+                    logger.info(" * Disabling the rule.")
+                rule_item.enabled = enabled
+
+            rules.update_rule(rule_item)
+            logger.info(f"Rule has been updated")
+        except Exception as err:
+            logger.error(f"Rule was not updated: {err}")
+
+
+class PAMGatewayActionDiscoverRuleRemoveCommand(PAMGatewayActionDiscoverCommandBase):
+
+    def __init__(self):
+        parser = argparse.ArgumentParser(prog='pam action discover rule remove')
+        PAMGatewayActionDiscoverRuleRemoveCommand.add_arguments_to_parser(parser)
+        super().__init__(parser)
+
+    @staticmethod
+    def add_arguments_to_parser(parser: argparse.ArgumentParser):
+        parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                            help='Gateway name of UID.')
+        parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid',
+                            action='store', help='PAM configuration UID, if gateway has multiple.')
+        parser.add_argument('--rule-id', '-i', required=False, dest='rule_id', action='store',
+                            help='Identifier for the rule')
+        parser.add_argument('--remove-all', required=False, dest='remove_all', action='store_true',
+                            help='Remove all the rules.')
+
+    def 
execute(self, context: KeeperParams, **kwargs):
+
+        gateway = kwargs.get("gateway")
+        vault = context.vault
+        gateway_context = GatewayContext.from_gateway(vault=vault,
+                                                      gateway_uid=gateway,
+                                                      configuration_uid=kwargs.get('configuration_uid'))
+        if gateway_context is None:
+            logger.error(f"Could not find the gateway configuration for {gateway}.")
+            return
+
+        rule_id = kwargs.get("rule_id")
+        remove_all = kwargs.get("remove_all")
+
+        if rule_id is None and remove_all is None:
+            logger.error(f'Either --rule-id or --remove-all are required.')
+            return
+
+        try:
+            rules = Rules(record=gateway_context.configuration, context=context)
+            if remove_all:
+                rules.remove_all(RuleTypeEnum.ACTION)
+                logger.info(f"All rules removed.")
+            else:
+
+                rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id)
+                if rule_item is None:
+                    raise ValueError("Rule Id does not exist.")
+                rules.remove_rule(rule_item)
+
+                logger.info(f"Rule has been removed.")
+        except Exception as err:
+            if remove_all:
+                logger.error(f"Rules have NOT been removed: {err}")
+            else:
+                logger.error(f"Rule was not removed: {err}")
diff --git a/keepercli-package/src/keepercli/commands/pam/keeper_pam.py b/keepercli-package/src/keepercli/commands/pam/keeper_pam.py
index 92f4b52c..39a69d61 100644
--- a/keepercli-package/src/keepercli/commands/pam/keeper_pam.py
+++ b/keepercli-package/src/keepercli/commands/pam/keeper_pam.py
@@ -3,16 +3,23 @@
 import requests
 from datetime import datetime
 
+from .pam_config import PAMConfigListCommand, PAMConfigNewCommand, PAMConfigEditCommand, PAMConfigRemoveCommand
+from .pam_gateway_action import (PAMGatewayActionServerInfoCommand, PAMGatewayActionRotateCommand,
+                                PAMGatewayActionJobCommand, PAMDiscoveryCommand, PAMActionServiceCommand,
+                                PAMActionSaasCommand, PAMDebugCommand)
 from .. import base
 from ...
import api from ...helpers import report_utils, router_utils, gateway_utils from ...params import KeeperParams + from keepersdk import utils +from keepersdk.vault import ksm_management logger = api.get_logger() + # Constants DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S' MILLISECONDS_TO_SECONDS = 1000 @@ -67,6 +74,11 @@ class PAMControllerCommand(base.GroupCommand): def __init__(self): super().__init__('PAM Controller') self.register_command(PAMGatewayCommand(), 'gateway', 'g') + self.register_command(PAMConfigCommand(), 'config', 'c') + self.register_command(PAMGatewayActionCommand(), 'action', 'a') + # self.register_command('rotation', PAMRotationCommand(), 'Manage Rotations', 'r') + # self.register_command('connection', PAMConnectionCommand(), 'Manage Connections', 'n') + # self.register_command('rbi', PAMRbiCommand(), 'Manage Remote Browser Isolation', 'b') class PAMGatewayCommand(base.GroupCommand): @@ -79,6 +91,30 @@ def __init__(self): self.register_command(PAMGatewaySetMaxInstancesCommand(), 'set-max-instances', 'smi') self.default_verb = 'list' + +class PAMConfigCommand(base.GroupCommand): + + def __init__(self): + super().__init__('PAM Configurations') + self.register_command(PAMConfigListCommand(), 'list', 'l') + self.register_command(PAMConfigNewCommand(), 'new', 'n') + self.register_command(PAMConfigEditCommand(), 'edit', 'e') + self.register_command(PAMConfigRemoveCommand(), 'remove', 'rm') + self.default_verb = 'list' + + +class PAMGatewayActionCommand(base.GroupCommand): + def __init__(self): + super().__init__('PAM Gateway Action') + self.register_command(PAMGatewayActionServerInfoCommand(), 'gateway-info', 'i') + self.register_command(PAMGatewayActionRotateCommand(), 'rotate', 'r') + self.register_command(PAMGatewayActionJobCommand(), 'job-info', 'ji') + self.register_command(PAMGatewayActionJobCommand(), 'job-cancel', 'jc') + self.register_command(PAMDiscoveryCommand(), 'discover', 'd') + self.register_command(PAMActionServiceCommand(), 'service', 's') + 
self.register_command(PAMActionSaasCommand(), 'saas', 'sa') + self.register_command(PAMDebugCommand(), 'debug', 'd') + class PAMGatewayListCommand(base.ArgparseCommand): def __init__(self): @@ -527,7 +563,8 @@ def execute(self, context: KeeperParams, **kwargs): token_expire_in_min = kwargs.get('token_expire_in_min') self._log_gateway_creation_params(gateway_name, ksm_app, token_expire_in_min) - one_time_token = gateway_utils.create_gateway(vault, gateway_name, ksm_app, token_expire_in_min) + ksm_app_info = ksm_management.get_secrets_manager_app(vault, ksm_app) + one_time_token = gateway_utils.create_gateway(vault, gateway_name, ksm_app_info.uid, token_expire_in_min) if is_return_value: return one_time_token @@ -650,4 +687,3 @@ def _set_max_instances(self, vault, gateway, max_instances): logger.info('%s: max instance count set to %d', gateway.controllerName, max_instances) except Exception as e: raise base.CommandError(f'Error setting max instances: {e}') - diff --git a/keepercli-package/src/keepercli/commands/pam/pam_config.py b/keepercli-package/src/keepercli/commands/pam/pam_config.py new file mode 100644 index 00000000..a1fe2d05 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/pam_config.py @@ -0,0 +1,1105 @@ +import argparse +import json +import re + +from .. import base +from ... import api +from ...helpers import report_utils, gateway_utils, folder_utils +from ...params import KeeperParams +from ..record_edit import RecordEditMixin + + +from keepersdk import utils +from keepersdk.proto import pam_pb2, record_pb2 +from keepersdk.helpers import config_utils +from keepersdk.vault import vault_online, vault_utils, vault_record, record_management +from keepersdk.helpers.pam_config_facade import PamConfigurationRecordFacade +from keepersdk.helpers.tunnel.tunnel_graph import TunnelDAG, tunnel_utils +from keepersdk.helpers.keeper_dag import dag_utils +from .. 
import record_edit + + +logger = api.get_logger() + + +# PAM Configuration record types +PAM_CONFIG_RECORD_TYPES = ( + 'pamAwsConfiguration', 'pamAzureConfiguration', 'pamGcpConfiguration', + 'pamDomainConfiguration', 'pamNetworkConfiguration', 'pamOciConfiguration' +) + + +class PAMConfigListCommand(base.ArgparseCommand): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam config list') + PAMConfigListCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--config', '-c', required=False, dest='pam_configuration', action='store', + help='Specific PAM Configuration UID') + parser.add_argument('--verbose', '-v', required=False, dest='verbose', action='store_true', help='Verbose') + parser.add_argument('--format', dest='format', action='store', choices=['table', 'json'], default='table', + help='Output format (table, json)') + + def execute(self, context: KeeperParams, **kwargs): + self._validate_vault_and_permissions(context) + + vault = context.vault + pam_configuration_uid = kwargs.get('pam_configuration') + is_verbose = kwargs.get('verbose') + format_type = kwargs.get('format', 'table') + + if not pam_configuration_uid: + result = self._list_all_configurations(vault, is_verbose, format_type) + if format_type == 'json' and result: + return result + else: + result = self._list_single_configuration(vault, pam_configuration_uid, is_verbose, format_type) + if format_type == 'json' and result: + return result + + if format_type == 'table': + self._print_tunneling_config(vault, pam_configuration_uid) + + def _validate_vault_and_permissions(self, context: KeeperParams): + """Validates that vault is initialized and user has enterprise admin permissions.""" + if not context.vault: + raise ValueError("Vault is not initialized, login to initialize the vault.") + base.require_enterprise_admin(context) + + def _print_tunneling_config(self, vault: 
vault_online.VaultOnline, config_uid: str): + """Prints tunneling configuration for a specific PAM configuration.""" + encrypted_session_token, encrypted_transmission_key, _ = gateway_utils.get_keeper_tokens(vault) + tmp_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, config_uid, is_config=True) + tmp_dag.print_tunneling_config(config_uid, None) + + def _list_single_configuration(self, vault: vault_online.VaultOnline, config_uid: str, + is_verbose: bool, format_type: str): + """Lists details for a single PAM configuration.""" + configuration = self._load_and_validate_configuration(vault, config_uid, format_type) + facade = self._create_facade(configuration) + shared_folder = self._load_shared_folder(vault, facade.folder_uid) + + if format_type == 'json': + return self._format_single_config_json(configuration, facade, shared_folder) + else: + self._format_single_config_table(configuration, facade, shared_folder) + + def _list_all_configurations(self, vault: vault_online.VaultOnline, is_verbose: bool, format_type: str): + """Lists all PAM configurations.""" + configs_data = [] + table = [] + headers = self._build_list_headers(is_verbose, format_type) + + for config_record in self._find_pam_configurations(vault): + facade = self._create_facade(config_record) + shared_folder_parents = vault_utils.get_folders_for_record(vault.vault_data, config_record.record_uid) + + if not shared_folder_parents: + logger.warning(f'Following configuration is not in the shared folder: UID: %s, Title: %s', + config_record.record_uid, config_record.title) + continue + + shared_folder = shared_folder_parents[0] + full_record = vault.vault_data.load_record(config_record.record_uid) + + if format_type == 'json': + config_data = self._build_config_json_data(config_record, facade, shared_folder, full_record, is_verbose) + configs_data.append(config_data) + else: + row = self._build_config_table_row(config_record, facade, shared_folder, full_record, is_verbose) + 
table.append(row) + + return self._format_output(configs_data, table, headers, format_type) + + def _load_and_validate_configuration(self, vault: vault_online.VaultOnline, config_uid: str, format_type: str): + """Loads and validates a PAM configuration record.""" + configuration = vault.vault_data.load_record(config_uid) + if not configuration: + self._handle_error(format_type, f'Configuration {config_uid} not found') + + if configuration.version != 6 or not isinstance(configuration, vault_record.TypedRecord): + self._handle_error(format_type, f'{config_uid} is not PAM Configuration') + + return configuration + + def _handle_error(self, format_type: str, error_message: str): + """Handles errors based on output format.""" + if format_type == 'json': + return json.dumps({"error": error_message}) + else: + raise Exception(error_message) + + def _create_facade(self, configuration): + """Creates a PAM configuration facade for the given record.""" + facade = PamConfigurationRecordFacade() + facade.record = configuration + return facade + + def _load_shared_folder(self, vault: vault_online.VaultOnline, folder_uid: str): + """Loads shared folder if it exists.""" + if folder_uid and folder_uid in vault.vault_data._shared_folders: + return vault.vault_data.load_shared_folder(folder_uid) + return None + + def _find_pam_configurations(self, vault: vault_online.VaultOnline): + """Finds all PAM configuration records.""" + for record in vault.vault_data.find_records(criteria='', record_type=None, record_version=6): + if record.record_type in PAM_CONFIG_RECORD_TYPES: + yield record + else: + logger.warning(f'Following configuration has unsupported type: UID: %s, Title: %s', + record.record_uid, record.title) + + def _build_list_headers(self, is_verbose: bool, format_type: str): + """Builds headers for the configuration list output.""" + if format_type == 'json': + headers = ['uid', 'config_name', 'config_type', 'shared_folder', 'gateway_uid', 'resource_record_uids'] + if 
is_verbose: + headers.append('fields') + else: + headers = ['UID', 'Config Name', 'Config Type', 'Shared Folder', 'Gateway UID', 'Resource Record UIDs'] + if is_verbose: + headers.append('Fields') + return headers + + def _extract_config_fields(self, record, is_verbose: bool): + """Extracts field data from a configuration record.""" + fields_data = {} if is_verbose else [] + + for field in record.fields: + if field.type in ('pamResources', 'fileRef'): + continue + + values = list(field.get_external_value()) + if not values: + continue + + field_name = field.external_name() + if field.type == 'schedule': + field_name = 'Default Schedule' + + value_str = ', '.join(field.get_external_value()) + if is_verbose: + fields_data[field_name] = value_str + else: + fields_data.append(f'{field_name}: {value_str}') + + return fields_data + + def _build_config_json_data(self, config_record, facade, shared_folder, full_record, is_verbose: bool): + """Builds JSON data structure for a configuration.""" + config_data = { + "uid": config_record.record_uid, + "config_name": config_record.title, + "config_type": config_record.record_type, + "shared_folder": { + "name": shared_folder.name, + "uid": shared_folder.folder_uid + }, + "gateway_uid": facade.controller_uid, + "resource_record_uids": facade.resource_ref + } + + if is_verbose: + config_data["fields"] = self._extract_config_fields(full_record, is_verbose=True) + + return config_data + + def _build_config_table_row(self, config_record, facade, shared_folder, full_record, is_verbose: bool): + """Builds a table row for a configuration.""" + row = [ + config_record.record_uid, + config_record.title, + config_record.record_type, + f'{shared_folder.name} ({shared_folder.folder_uid})', + facade.controller_uid, + facade.resource_ref + ] + + if is_verbose: + fields = self._extract_config_fields(full_record, is_verbose=False) + row.append(fields) + + return row + + def _format_output(self, configs_data, table, headers, format_type: str): + 
"""Formats and outputs the final result.""" + if format_type == 'json': + configs_data.sort(key=lambda x: x['config_name'] or '') + return json.dumps({"configurations": configs_data}, indent=2) + else: + table.sort(key=lambda x: (x[1] or '')) + report_utils.dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None) + + def _format_single_config_json(self, configuration, facade, shared_folder): + """Formats a single configuration as JSON.""" + config_data = { + "uid": configuration.record_uid, + "name": configuration.title, + "config_type": configuration.record_type, + "shared_folder": { + "name": shared_folder.name if shared_folder else None, + "uid": shared_folder.shared_folder_uid if shared_folder else None + } if shared_folder else None, + "gateway_uid": facade.controller_uid, + "resource_record_uids": facade.resource_ref, + "fields": {} + } + + for field in configuration.fields: + if field.type in ('pamResources', 'fileRef'): + continue + + values = list(field.get_external_value()) + if not values: + continue + + field_name = field.external_name() + if field.type == 'schedule': + field_name = 'Default Schedule' + + config_data["fields"][field_name] = values + + return json.dumps(config_data, indent=2) + + def _format_single_config_table(self, configuration, facade, shared_folder): + """Formats a single configuration as a table.""" + table = [] + header = ['name', 'value'] + + table.append(['UID', configuration.record_uid]) + table.append(['Name', configuration.title]) + table.append(['Config Type', configuration.record_type]) + table.append(['Shared Folder', f'{shared_folder.name} ({shared_folder.shared_folder_uid})' if shared_folder else '']) + table.append(['Gateway UID', facade.controller_uid]) + table.append(['Resource Record UIDs', facade.resource_ref]) + + for field in configuration.fields: + if field.type in ('pamResources', 'fileRef'): + continue + + values = list(field.get_external_value()) + if not values: + 
continue + + field_name = field.external_name() + if field.type == 'schedule': + field_name = 'Default Schedule' + + table.append([field_name, values]) + + report_utils.dump_report_data(table, header, no_header=True, right_align=(0,)) + + +class PamConfigurationEditMixin(record_edit.RecordEditMixin): + pam_record_types = None + + def __init__(self): + super().__init__() + + @staticmethod + def get_pam_record_types(vault: vault_online.VaultOnline): + """Gets cached list of PAM record types.""" + if PamConfigurationEditMixin.pam_record_types is None: + rts = [x for x in vault.vault_data._custom_record_types if x.scope // 1000000 == record_pb2.RT_PAM] + PamConfigurationEditMixin.pam_record_types = [rt.id for rt in rts] + return PamConfigurationEditMixin.pam_record_types + + def parse_pam_configuration(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, **kwargs): + """Parses PAM configuration fields: gateway, shared folder, and resource records.""" + field = self._get_or_create_pam_resources_field(record) + value = self._ensure_pam_resources_value(field) + + self._parse_gateway_uid(vault, value, kwargs) + self._parse_shared_folder_uid(vault, record, value, kwargs) + self._parse_resource_records(vault, value, kwargs) + + def _get_or_create_pam_resources_field(self, record: vault_record.TypedRecord): + """Gets or creates the pamResources field.""" + field = record.get_typed_field('pamResources') + if not field: + field = vault_record.TypedField.new_field('pamResources', {}) + record.fields.append(field) + return field + + def _ensure_pam_resources_value(self, field): + """Ensures the pamResources field has a value dictionary.""" + if len(field.value) == 0: + field.value.append({}) + return field.value[0] + + def _parse_gateway_uid(self, vault: vault_online.VaultOnline, value: dict, kwargs: dict): + """Resolves and sets the gateway UID from kwargs.""" + gateway = kwargs.get('gateway_uid') + if not gateway: + return + + gateways = 
gateway_utils.get_all_gateways(vault) + gateway_uid = next( + (utils.base64_url_encode(x.controllerUid) for x in gateways + if utils.base64_url_encode(x.controllerUid) == gateway + or x.controllerName.casefold() == gateway.casefold()), + None + ) + + if gateway_uid: + value['controllerUid'] = gateway_uid + + def _parse_shared_folder_uid(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, + value: dict, kwargs: dict): + """Resolves and sets the shared folder UID from kwargs or existing record.""" + folder_name = kwargs.get('shared_folder_uid') + shared_folder_uid = None + + if folder_name: + shared_folder_uid = self._find_shared_folder_by_name_or_uid(vault, folder_name) + + if not shared_folder_uid: + shared_folder_uid = self._get_existing_shared_folder_uid(record) + + if shared_folder_uid: + value['folderUid'] = shared_folder_uid + else: + raise base.CommandError('Shared Folder not found') + + def _find_shared_folder_by_name_or_uid(self, vault: vault_online.VaultOnline, folder_name: str): + """Finds a shared folder by UID or name.""" + shared_folder_cache = vault.vault_data._shared_folders + + if folder_name in shared_folder_cache: + return folder_name + + for sf_uid in shared_folder_cache: + sf = vault.vault_data.load_shared_folder(sf_uid) + if sf and sf.name.casefold() == folder_name.casefold(): + return sf_uid + + return None + + def _get_existing_shared_folder_uid(self, record: vault_record.TypedRecord): + """Gets the existing shared folder UID from the record.""" + for f in record.fields: + if f.type == 'pamResources' and f.value and len(f.value) > 0: + return f.value[0].get('folderUid') + return None + + def _parse_resource_records(self, vault: vault_online.VaultOnline, value: dict, kwargs: dict): + """Removes resource records from the configuration.""" + remove_records = kwargs.get('remove_records') + if not remove_records: + return + + pam_record_lookup = self._build_pam_record_lookup(vault) + record_uids = set(value.get('resourceRef', 
[])) + + if isinstance(remove_records, list): + for r in remove_records: + record_uid = pam_record_lookup.get(r) or pam_record_lookup.get(r.lower()) + if record_uid: + record_uids.discard(record_uid) + else: + logger.warning(f'Failed to find PAM record: {r}') + + value['resourceRef'] = list(record_uids) + + def _build_pam_record_lookup(self, vault: vault_online.VaultOnline): + """Builds a lookup dictionary for PAM records by UID and title.""" + pam_record_lookup = {} + rti = PamConfigurationEditMixin.get_pam_record_types(vault) + + for r in vault.vault_data.records(): + if r.record_type in rti: + pam_record_lookup[r.record_uid] = r.record_uid + pam_record_lookup[r.title.lower()] = r.record_uid + + return pam_record_lookup + + @staticmethod + def resolve_single_record(vault: vault_online.VaultOnline, record_name: str, rec_type: str = ''): + """Resolves a single record by name and optional type.""" + for r in vault.vault_data.records(): + if r.title == record_name and (not rec_type or rec_type == r.record_type): + return r + return None + + def parse_properties(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, **kwargs): + """Parses all configuration properties based on record type.""" + self.parse_pam_configuration(vault, record, **kwargs) + + extra_properties = [] + self._parse_common_properties(extra_properties, kwargs) + self._parse_type_specific_properties(vault, record, extra_properties, kwargs) + + if extra_properties: + parsed_fields = [record_edit.RecordEditMixin.parse_field(x) for x in extra_properties] + self.assign_typed_fields(record, parsed_fields) + + def _parse_common_properties(self, extra_properties: list, kwargs: dict): + """Parses properties common to all PAM configuration types.""" + port_mapping = kwargs.get('port_mapping') + if isinstance(port_mapping, list) and len(port_mapping) > 0: + pm = "\n".join(port_mapping) + extra_properties.append(f'multiline.portMapping={pm}') + + schedule = kwargs.get('default_schedule') + if 
schedule: + valid, err = validate_cron_expression(schedule, for_rotation=True) + if not valid: + raise base.CommandError(f'Invalid CRON "{schedule}" Error: {err}') + extra_properties.append(f'schedule.defaultRotationSchedule=$JSON:{{"type": "CRON", "cron": "{schedule}", "tz": "Etc/UTC"}}') + else: + extra_properties.append('schedule.defaultRotationSchedule=On-Demand') + + def _parse_type_specific_properties(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, + extra_properties: list, kwargs: dict): + """Parses properties specific to each configuration type.""" + record_type = record.record_type + + if record_type == 'pamNetworkConfiguration': + self._parse_network_properties(extra_properties, kwargs) + elif record_type == 'pamAwsConfiguration': + self._parse_aws_properties(extra_properties, kwargs) + elif record_type == 'pamGcpConfiguration': + self._parse_gcp_properties(extra_properties, kwargs) + elif record_type == 'pamAzureConfiguration': + self._parse_azure_properties(extra_properties, kwargs) + elif record_type == 'pamDomainConfiguration': + self._parse_domain_properties(vault, record, extra_properties, kwargs) + elif record_type == 'pamOciConfiguration': + self._parse_oci_properties(extra_properties, kwargs) + + def _parse_network_properties(self, extra_properties: list, kwargs: dict): + """Parses network configuration properties.""" + network_id = kwargs.get('network_id') + if network_id: + extra_properties.append(f'text.networkId={network_id}') + + network_cidr = kwargs.get('network_cidr') + if network_cidr: + extra_properties.append(f'text.networkCIDR={network_cidr}') + + def _parse_aws_properties(self, extra_properties: list, kwargs: dict): + """Parses AWS configuration properties.""" + aws_id = kwargs.get('aws_id') + if aws_id: + extra_properties.append(f'text.awsId={aws_id}') + + access_key_id = kwargs.get('access_key_id') + if access_key_id: + extra_properties.append(f'secret.accessKeyId={access_key_id}') + + access_secret_key = 
kwargs.get('access_secret_key') + if access_secret_key: + extra_properties.append(f'secret.accessSecretKey={access_secret_key}') + + region_names = kwargs.get('region_names') + if region_names: + regions = '\n'.join(region_names) + extra_properties.append(f'multiline.regionNames={regions}') + + def _parse_gcp_properties(self, extra_properties: list, kwargs: dict): + """Parses GCP configuration properties.""" + gcp_id = kwargs.get('gcp_id') + if gcp_id: + extra_properties.append(f'text.pamGcpId={gcp_id}') + + service_account_key = kwargs.get('service_account_key') + if service_account_key: + extra_properties.append(f'json.pamServiceAccountKey={service_account_key}') + + google_admin_email = kwargs.get('google_admin_email') + if google_admin_email: + extra_properties.append(f'email.pamGoogleAdminEmail={google_admin_email}') + + gcp_region = kwargs.get('region_names') + if gcp_region: + regions = '\n'.join(gcp_region) + extra_properties.append(f'multiline.pamGcpRegionName={regions}') + + def _parse_azure_properties(self, extra_properties: list, kwargs: dict): + """Parses Azure configuration properties.""" + azure_id = kwargs.get('azure_id') + if azure_id: + extra_properties.append(f'text.azureId={azure_id}') + + client_id = kwargs.get('client_id') + if client_id: + extra_properties.append(f'secret.clientId={client_id}') + + client_secret = kwargs.get('client_secret') + if client_secret: + extra_properties.append(f'secret.clientSecret={client_secret}') + + subscription_id = kwargs.get('subscription_id') + if subscription_id: + extra_properties.append(f'secret.subscriptionId={subscription_id}') + + tenant_id = kwargs.get('tenant_id') + if tenant_id: + extra_properties.append(f'secret.tenantId={tenant_id}') + + resource_groups = kwargs.get('resource_groups') + if isinstance(resource_groups, list) and len(resource_groups) > 0: + rg = '\n'.join(resource_groups) + extra_properties.append(f'multiline.resourceGroups={rg}') + + def _parse_domain_properties(self, vault: 
vault_online.VaultOnline, record: vault_record.TypedRecord, + extra_properties: list, kwargs: dict): + """Parses domain configuration properties.""" + domain_id = kwargs.get('domain_id') + if domain_id: + extra_properties.append(f'text.pamDomainId={domain_id}') + + self._parse_domain_hostname(extra_properties, kwargs) + self._parse_domain_ssl_settings(extra_properties, kwargs) + self._parse_domain_network_settings(extra_properties, kwargs) + self._parse_domain_admin_credential(vault, record, kwargs) + + def _parse_domain_hostname(self, extra_properties: list, kwargs: dict): + """Parses domain hostname and port settings.""" + host = str(kwargs.get('domain_hostname') or '').strip() + port = str(kwargs.get('domain_port') or '').strip() + if host or port: + val = json.dumps({"hostName": host, "port": port}) + extra_properties.append(f"f.pamHostname=$JSON:{val}") + + def _parse_domain_ssl_settings(self, extra_properties: list, kwargs: dict): + """Parses domain SSL and scan settings.""" + domain_use_ssl = dag_utils.value_to_boolean(kwargs.get('domain_use_ssl')) + if domain_use_ssl is not None: + val = 'true' if domain_use_ssl else 'false' + extra_properties.append(f'checkbox.useSSL={val}') + + domain_scan_dc_cidr = dag_utils.value_to_boolean(kwargs.get('domain_scan_dc_cidr')) + if domain_scan_dc_cidr is not None: + val = 'true' if domain_scan_dc_cidr else 'false' + extra_properties.append(f'checkbox.scanDCCIDR={val}') + + def _parse_domain_network_settings(self, extra_properties: list, kwargs: dict): + """Parses domain network CIDR settings.""" + domain_network_cidr = kwargs.get('domain_network_cidr') + if domain_network_cidr: + extra_properties.append(f'text.networkCIDR={domain_network_cidr}') + + def _parse_domain_admin_credential(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, + kwargs: dict): + """Parses and validates domain administrative credential.""" + domain_administrative_credential = kwargs.get('domain_administrative_credential') + dac 
= str(domain_administrative_credential or '') + + if not dac: + return + + if kwargs.get('force_domain_admin', False) is True: + if not re.search('^[A-Za-z0-9-_]{22}$', dac): + logger.warning(f'Invalid Domain Admin User UID: "{dac}" (skipped)') + dac = '' + else: + adm_rec = PamConfigurationEditMixin.resolve_single_record(vault, dac, 'pamUser') + if adm_rec and isinstance(adm_rec, vault_record.TypedRecord) and adm_rec.record_type == 'pamUser': + dac = adm_rec.record_uid + else: + logger.warning(f'Domain Admin User UID: "{dac}" not found (skipped).') + dac = '' + + if dac: + prf = record.get_typed_field('pamResources') + prf.value = prf.value or [{}] + prf.value[0]["adminCredentialRef"] = dac + + def _parse_oci_properties(self, extra_properties: list, kwargs: dict): + """Parses OCI configuration properties.""" + oci_id = kwargs.get('oci_id') + if oci_id: + extra_properties.append(f'text.pamOciId={oci_id}') + + oci_admin_id = kwargs.get('oci_admin_id') + if oci_admin_id: + extra_properties.append(f'secret.adminOcid={oci_admin_id}') + + oci_admin_public_key = kwargs.get('oci_admin_public_key') + if oci_admin_public_key: + extra_properties.append(f'secret.adminPublicKey={oci_admin_public_key}') + + oci_admin_private_key = kwargs.get('oci_admin_private_key') + if oci_admin_private_key: + extra_properties.append(f'secret.adminPrivateKey={oci_admin_private_key}') + + oci_tenancy = kwargs.get('oci_tenancy') + if oci_tenancy: + extra_properties.append(f'text.tenancyOci={oci_tenancy}') + + oci_region = kwargs.get('oci_region') + if oci_region: + extra_properties.append(f'text.regionOci={oci_region}') + + def verify_required(self, record: vault_record.TypedRecord): + """Verifies and sets default values for required fields.""" + for field in record.fields: + if field.required and len(field.value) == 0: + if field.type == 'schedule': + field.value = [{'type': 'ON_DEMAND'}] + else: + self.warnings.append(f'Empty required field: "{field.external_name()}"') + + for custom in 
record.custom: + if custom.required: + custom.required = False + + +# Configuration type mapping +CONFIG_TYPE_TO_RECORD_TYPE = { + 'aws': 'pamAwsConfiguration', + 'azure': 'pamAzureConfiguration', + 'local': 'pamNetworkConfiguration', + 'network': 'pamNetworkConfiguration', + 'gcp': 'pamGcpConfiguration', + 'domain': 'pamDomainConfiguration', + 'oci': 'pamOciConfiguration' +} + + +class PAMConfigNewCommand(base.ArgparseCommand, PamConfigurationEditMixin): + + def __init__(self): + self.choices = ['on', 'off', 'default'] + parser = argparse.ArgumentParser(prog='pam config new') + PAMConfigNewCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + choices = ['on', 'off', 'default'] + parser.add_argument('--config-type', '-ct', dest='config_type', action='store', + choices=['network', 'aws', 'azure'], help='PAM Configuration Type', ) + parser.add_argument('--title', '-t', dest='title', action='store', help='Title of the PAM Configuration') + parser.add_argument('--gateway', '-g', dest='gateway', action='store', help='Gateway UID or Name') + parser.add_argument('--shared-folder', '-sf', dest='shared_folder', action='store', + help='Share Folder where this PAM Configuration is stored. 
Should be one of the folders to ' + 'which the gateway has access to.') + parser.add_argument('--resource-record', '-rr', dest='resource_records', action='append', + help='Resource Record UID') + parser.add_argument('--schedule', '-sc', dest='default_schedule', action='store', help='Default Schedule: Use CRON syntax') + parser.add_argument('--port-mapping', '-pm', dest='port_mapping', action='append', help='Port Mapping') + network_group = parser.add_argument_group('network', 'Local network configuration') + network_group.add_argument('--network-id', dest='network_id', action='store', help='Network ID') + network_group.add_argument('--network-cidr', dest='network_cidr', action='store', help='Network CIDR') + aws_group = parser.add_argument_group('aws', 'AWS configuration') + aws_group.add_argument('--aws-id', dest='aws_id', action='store', help='AWS ID') + aws_group.add_argument('--access-key-id', dest='access_key_id', action='store', help='Access Key Id') + aws_group.add_argument('--access-secret-key', dest='access_secret_key', action='store', help='Access Secret Key') + aws_group.add_argument('--region-name', dest='region_names', action='append', help='Region Names') + azure_group = parser.add_argument_group('azure', 'Azure configuration') + azure_group.add_argument('--azure-id', dest='azure_id', action='store', help='Azure Id') + azure_group.add_argument('--client-id', dest='client_id', action='store', help='Client Id') + azure_group.add_argument('--client-secret', dest='client_secret', action='store', help='Client Secret') + azure_group.add_argument('--subscription_id', dest='subscription_id', action='store', + help='Subscription Id') + azure_group.add_argument('--tenant-id', dest='tenant_id', action='store', help='Tenant Id') + azure_group.add_argument('--resource-group', dest='resource_group', action='append', help='Resource Group') + + def execute(self, context: KeeperParams, **kwargs): + self.warnings.clear() + self._validate_vault(context) + + vault = 
context.vault + record_type = self._resolve_record_type(kwargs) + title = self._validate_title(kwargs) + + record = self._create_record(vault, record_type, title) + self._resolve_shared_folder_path(context, kwargs) + self.parse_properties(vault, record, **kwargs) + + gateway_uid, shared_folder_uid, admin_cred_ref = self._extract_pam_resources(record, kwargs) + self._validate_shared_folder(shared_folder_uid, kwargs) + self._warn_if_gateway_missing(gateway_uid, kwargs) + + self.verify_required(record) + self._create_and_configure_record(vault, record, shared_folder_uid, gateway_uid, admin_cred_ref, kwargs) + + self._log_warnings() + return record.record_uid + + def _validate_vault(self, context: KeeperParams): + """Validates that vault is initialized.""" + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + def _resolve_record_type(self, kwargs: dict) -> str: + """Resolves the record type from config type parameter.""" + config_type = kwargs.get('config_type') + if not config_type: + raise base.CommandError('--config-type parameter is required') + + record_type = CONFIG_TYPE_TO_RECORD_TYPE.get(config_type) + if not record_type: + supported = ', '.join(CONFIG_TYPE_TO_RECORD_TYPE.keys()) + raise base.CommandError(f'--config-type {config_type} is not supported - supported options: {supported}') + + return record_type + + def _validate_title(self, kwargs: dict) -> str: + """Validates that title is provided.""" + title = kwargs.get('title') + if not title: + raise base.CommandError('--title parameter is required') + return title + + def _create_record(self, vault: vault_online.VaultOnline, record_type: str, title: str): + """Creates a new typed record with the specified type and title.""" + record = vault_record.TypedRecord(version=6) + record.type_name = record_type + record.title = title + + record_type_def = vault.vault_data.get_record_type_by_name(record_type) + if record_type_def and record_type_def.fields: 
+ RecordEditMixin.adjust_typed_record_fields(record, record_type_def.fields) + + return record + + def _resolve_shared_folder_path(self, context: KeeperParams, kwargs: dict): + """Resolves shared folder path to UID.""" + sf_name = kwargs.get('shared_folder_uid', '') + if not sf_name: + return + + fpath = folder_utils.try_resolve_path(context, sf_name) + if fpath and len(fpath) >= 2 and fpath[-1] == '': + sfuid = fpath[-2].uid + if sfuid: + kwargs['shared_folder_uid'] = sfuid + + def _extract_pam_resources(self, record: vault_record.TypedRecord, kwargs: dict): + """Extracts gateway UID, shared folder UID, and admin credential ref from record.""" + field = record.get_typed_field('pamResources') + if not field: + raise base.CommandError('PAM configuration record does not contain resource field') + + value = field.get_default_value(dict) + if not value: + return None, None, None + + gateway_uid = value.get('controllerUid') + shared_folder_uid = value.get('folderUid') + admin_cred_ref = None + + if record.record_type == 'pamDomainConfiguration' and not kwargs.get('force_domain_admin', False): + admin_cred_ref = value.get('adminCredentialRef') + + return gateway_uid, shared_folder_uid, admin_cred_ref + + def _validate_shared_folder(self, shared_folder_uid: str, kwargs: dict): + """Validates that shared folder UID is present.""" + if not shared_folder_uid: + raise base.CommandError('--shared-folder parameter is required to create a PAM configuration') + + def _warn_if_gateway_missing(self, gateway_uid: str, kwargs: dict): + """Warns if gateway is not found.""" + if not gateway_uid: + gw_name = kwargs.get('gateway_uid') or '' + logger.warning(f'Gateway "{gw_name}" not found.') + + def _create_and_configure_record(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, + shared_folder_uid: str, gateway_uid: str, admin_cred_ref: str, kwargs: dict): + """Creates the record and configures tunneling, DAG, and controller.""" + 
config_utils.pam_configuration_create_record_v6(vault, record, shared_folder_uid) + + self._configure_tunneling(vault, record, admin_cred_ref, kwargs) + + vault.sync_down() + record_management.move_vault_objects(vault, [record.record_uid], shared_folder_uid) + vault.sync_down() + + if gateway_uid: + self._set_configuration_controller(vault, record.record_uid, gateway_uid) + + def _configure_tunneling(self, vault: vault_online.VaultOnline, record: vault_record.TypedRecord, + admin_cred_ref: str, kwargs: dict): + """Configures tunneling settings for the configuration.""" + encrypted_session_token, encrypted_transmission_key, _ = tunnel_utils.get_keeper_tokens(vault) + tmp_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, + record_uid=record.record_uid, is_config=True) + + tmp_dag.edit_tunneling_config( + kwargs.get('connections'), + kwargs.get('tunneling'), + kwargs.get('rotation'), + kwargs.get('recording'), + kwargs.get('typescriptrecording'), + kwargs.get('remotebrowserisolation') + ) + + if admin_cred_ref: + tmp_dag.link_user_to_config_with_options(admin_cred_ref, is_admin='on') + + tmp_dag.print_tunneling_config(record.record_uid, None) + + def _set_configuration_controller(self, vault: vault_online.VaultOnline, config_uid: str, gateway_uid: str): + """Sets the controller for the PAM configuration.""" + pcc = pam_pb2.PAMConfigurationController() + pcc.configurationUid = utils.base64_url_decode(config_uid) + pcc.controllerUid = utils.base64_url_decode(gateway_uid) + vault.keeper_auth.execute_auth_rest('pam/set_configuration_controller', pcc) + + def _log_warnings(self): + """Logs all warnings.""" + for w in self.warnings: + logger.warning(w) + + +class PAMConfigEditCommand(base.ArgparseCommand, PamConfigurationEditMixin): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam config edit') + PAMConfigEditCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def 
add_arguments_to_parser(parser: argparse.ArgumentParser): + choices = ['on', 'off', 'default'] + parser.add_argument('--config-type', '-ct', dest='config_type', action='store', + choices=['network', 'aws', 'azure'], help='PAM Configuration Type', ) + parser.add_argument('--title', '-t', dest='title', action='store', help='Title of the PAM Configuration') + parser.add_argument('--gateway', '-g', dest='gateway', action='store', help='Gateway UID or Name') + parser.add_argument('--shared-folder', '-sf', dest='shared_folder', action='store', + help='Share Folder where this PAM Configuration is stored. Should be one of the folders to ' + 'which the gateway has access to.') + parser.add_argument('--resource-record', '-rr', dest='resource_records', action='append', + help='Resource Record UID') + parser.add_argument('--schedule', '-sc', dest='default_schedule', action='store', help='Default Schedule: Use CRON syntax') + parser.add_argument('--port-mapping', '-pm', dest='port_mapping', action='append', help='Port Mapping') + network_group = parser.add_argument_group('network', 'Local network configuration') + network_group.add_argument('--network-id', dest='network_id', action='store', help='Network ID') + network_group.add_argument('--network-cidr', dest='network_cidr', action='store', help='Network CIDR') + aws_group = parser.add_argument_group('aws', 'AWS configuration') + aws_group.add_argument('--aws-id', dest='aws_id', action='store', help='AWS ID') + aws_group.add_argument('--access-key-id', dest='access_key_id', action='store', help='Access Key Id') + aws_group.add_argument('--access-secret-key', dest='access_secret_key', action='store', help='Access Secret Key') + aws_group.add_argument('--region-name', dest='region_names', action='append', help='Region Names') + azure_group = parser.add_argument_group('azure', 'Azure configuration') + azure_group.add_argument('--azure-id', dest='azure_id', action='store', help='Azure Id') + azure_group.add_argument('--client-id', 
dest='client_id', action='store', help='Client Id') + azure_group.add_argument('--client-secret', dest='client_secret', action='store', help='Client Secret') + azure_group.add_argument('--subscription_id', dest='subscription_id', action='store', + help='Subscription Id') + azure_group.add_argument('--tenant-id', dest='tenant_id', action='store', help='Tenant Id') + azure_group.add_argument('--resource-group', dest='resource_group', action='append', help='Resource Group') + + def execute(self, context: KeeperParams, **kwargs): + self.warnings.clear() + self._validate_vault(context) + + vault = context.vault + configuration = self._find_configuration(vault, kwargs.get('config')) + self._validate_configuration(configuration, kwargs.get('config')) + + self._update_record_type_if_needed(vault, configuration, kwargs) + self._update_title_if_provided(configuration, kwargs) + + orig_gateway_uid, orig_shared_folder_uid = self._get_original_values(configuration) + self.parse_properties(vault, configuration, **kwargs) + self.verify_required(configuration) + + record_management.update_record(vault, configuration) + self._update_controller_and_folder_if_changed(vault, configuration, orig_gateway_uid, orig_shared_folder_uid) + + self._log_warnings() + vault.sync_down() + + def _validate_vault(self, context: KeeperParams): + """Validates that vault is initialized.""" + if not context.vault: + raise base.CommandError('Vault is not initialized. 
Login to initialize the vault.') + + def _find_configuration(self, vault: vault_online.VaultOnline, config_name: str): + """Finds a PAM configuration by UID or name.""" + if not config_name: + return None + + if config_name in vault.vault_data._records: + return vault.vault_data.load_record(config_name) + + config_name_lower = config_name.casefold() + for record in vault.vault_data.find_records(record_type=None, record_version=6): + if record.title.casefold() == config_name_lower: + return record + + return None + + def _validate_configuration(self, configuration, config_name: str): + """Validates that the configuration exists and is valid.""" + if not configuration: + raise base.CommandError(f'PAM configuration "{config_name}" not found') + + if not isinstance(configuration, vault_record.TypedRecord) or configuration.version != 6: + raise base.CommandError(f'PAM configuration "{config_name}" not found') + + def _update_record_type_if_needed(self, vault: vault_online.VaultOnline, configuration: vault_record.TypedRecord, + kwargs: dict): + """Updates the record type if config_type is provided and different.""" + config_type = kwargs.get('config_type') + if not config_type: + return + + record_type = CONFIG_TYPE_TO_RECORD_TYPE.get(config_type, configuration.record_type) + + if record_type != configuration.record_type: + configuration.type_name = record_type + record_type_def = vault.vault_data.get_record_type_by_name(record_type) + if record_type_def and record_type_def.fields: + RecordEditMixin.adjust_typed_record_fields(configuration, record_type_def.fields) + + def _update_title_if_provided(self, configuration: vault_record.TypedRecord, kwargs: dict): + """Updates the title if provided.""" + title = kwargs.get('title') + if title: + configuration.title = title + + def _get_original_values(self, configuration: vault_record.TypedRecord): + """Gets the original gateway and shared folder UIDs before updates.""" + field = configuration.get_typed_field('pamResources') + 
if not field: + raise base.CommandError('PAM configuration record does not contain resource field') + + value = field.get_default_value(dict) + if value: + return value.get('controllerUid') or '', value.get('folderUid') or '' + + return '', '' + + def _update_controller_and_folder_if_changed(self, vault: vault_online.VaultOnline, + configuration: vault_record.TypedRecord, + orig_gateway_uid: str, orig_shared_folder_uid: str): + """Updates controller and shared folder if they changed.""" + field = configuration.get_typed_field('pamResources') + value = field.get_default_value(dict) + if not value: + return + + gateway_uid = value.get('controllerUid') or '' + if gateway_uid != orig_gateway_uid: + self._set_configuration_controller(vault, configuration.record_uid, gateway_uid) + + shared_folder_uid = value.get('folderUid') or '' + if shared_folder_uid != orig_shared_folder_uid: + record_management.move_vault_objects(vault, [configuration.record_uid], shared_folder_uid) + + def _set_configuration_controller(self, vault: vault_online.VaultOnline, config_uid: str, gateway_uid: str): + """Sets the controller for the PAM configuration.""" + pcc = pam_pb2.PAMConfigurationController() + pcc.configurationUid = utils.base64_url_decode(config_uid) + pcc.controllerUid = utils.base64_url_decode(gateway_uid) + vault.keeper_auth.execute_auth_rest('pam/set_configuration_controller', pcc) + + def _log_warnings(self): + """Logs all warnings.""" + for w in self.warnings: + logger.warning(w) + + +class PAMConfigRemoveCommand(base.ArgparseCommand): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam config remove') + PAMConfigRemoveCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--config', '-c', required=True, dest='pam_config', action='store', + help='PAM Configuration UID. 
To view all rotation settings with their UIDs, use command `pam config list`') + + def execute(self, context: KeeperParams, **kwargs): + self._validate_vault(context) + + vault = context.vault + pam_config_name = kwargs.get('pam_config') + pam_config_uid = self._find_configuration_uid(vault, pam_config_name) + + if not pam_config_uid: + raise base.CommandError(f'Configuration "{pam_config_name}" not found') + + record_management.delete_vault_objects(vault, [pam_config_uid]) + vault.sync_down() + + def _validate_vault(self, context: KeeperParams): + """Validates that vault is initialized.""" + if not context.vault: + raise base.CommandError('Vault is not initialized. Login to initialize the vault.') + + def _find_configuration_uid(self, vault: vault_online.VaultOnline, config_name: str) -> str: + """Finds a PAM configuration UID by UID or name.""" + if not config_name: + return None + + if config_name in vault.vault_data._records: + record = vault.vault_data.load_record(config_name) + if record and record.version == 6 and record.record_type in PAM_CONFIG_RECORD_TYPES: + return config_name + + config_name_lower = config_name.casefold() + for record in vault.vault_data.find_records(record_type=None, record_version=6): + if record.record_type in PAM_CONFIG_RECORD_TYPES: + if record.record_uid == config_name or record.title.casefold() == config_name_lower: + return record.record_uid + + return None + + +def validate_cron_field(field: str, min_val: int, max_val: int) -> bool: + # Accept *, single number, range, step, list, and L suffix for last day/week + pattern = r'^(\*|\d+L?|L[W]?|\d+-\d+|\*/\d+|\d+(,\d+)*|\d+-\d+/\d+)$' + if not re.match(pattern, field): + return False + + def is_valid_number(n: str) -> bool: + # Strip L and W suffix if present (for last day/week expressions) + n_stripped = n.rstrip('LW') + return n_stripped and n_stripped.isdigit() and min_val <= int(n_stripped) <= max_val + + parts = re.split(r'[,\-/]', field) + return all(part == '*' or part in 
def validate_cron_field(field: str, min_val: int, max_val: int) -> bool:
    """Return True when a single CRON field is syntactically valid and within [min_val, max_val].

    Accepts '*', single numbers, ranges, steps, lists, and the 'L'/'LW'
    last-day/last-weekday forms.
    Fixes: the inner helper could return '' (empty string) instead of a bool,
    and the final expression both filtered out '*' and re-checked for it.
    """
    # Accept *, single number, range, step, list, and L suffix for last day/week
    pattern = r'^(\*|\d+L?|L[W]?|\d+-\d+|\*/\d+|\d+(,\d+)*|\d+-\d+/\d+)$'
    if not re.match(pattern, field):
        return False

    def is_valid_number(n: str) -> bool:
        # Strip L and W suffix if present (for last day/week expressions)
        n_stripped = n.rstrip('LW')
        # bool(...) keeps the declared return type even when the stripped string is empty.
        return bool(n_stripped) and n_stripped.isdigit() and min_val <= int(n_stripped) <= max_val

    # Split compound expressions (lists/ranges/steps) into atomic parts and range-check each.
    parts = re.split(r'[,\-/]', field)
    return all(part == '*' or part in ('L', 'LW') or is_valid_number(part) for part in parts)


def validate_cron_expression(expr: str, for_rotation: bool = False) -> tuple[bool, str]:
    """Validate a 5- or 6-field CRON expression.

    Returns (is_valid, message). When for_rotation is True, the stricter
    rotation-schedule format is enforced: all 6 fields are required and '?'
    placeholders in day-of-month / day-of-week are normalized to '*'.
    """
    parts = expr.strip().split()

    # All internal docs, MRD etc. specify that rotation schedule is using CRON format
    # but actually back-end don't accept all valid standard CRON and uses unspecified custom CRON format
    if for_rotation is True:
        if len(parts) != 6:
            return False, f"CRON: Rotation schedules require all 6 parts incl. seconds - ex. Daily at 04:00:00 cron: 0 0 4 * * ? got {len(parts)} parts"
        if not (parts[3] == '?' or parts[5] == "?"):
            logger.warning("CRON: Rotation schedule CRON format - must use ? character in one of these fields: day-of-week, day-of-month")
        # Normalize Quartz-style '?' placeholders so the generic field validators accept them.
        parts[3] = '*' if parts[3] == '?' else parts[3]
        parts[5] = '*' if parts[5] == '?' else parts[5]
        logger.debug("WARNING! Validating CRON expression for rotation - if you get 500 type errors make sure to validate your CRON using web vault UI")

    if len(parts) not in [5, 6]:
        return False, f"CRON: Expected 5 or 6 fields, got {len(parts)}"

    if len(parts) == 6:
        seconds, minute, hour, dom, month, dow = parts
        if not validate_cron_field(seconds, 0, 59):
            return False, "CRON: Invalid seconds field"
    else:
        minute, hour, dom, month, dow = parts

    validators = [
        (minute, 0, 59, "minute"),
        (hour, 0, 23, "hour"),
        (dom, 1, 31, "day of month"),
        (month, 1, 12, "month"),
        (dow, 0, 7, "day of week")
    ]

    for field, min_val, max_val, name in validators:
        if not validate_cron_field(field, min_val, max_val):
            return False, f"CRON: Invalid {name} field"

    return True, "Valid cron expression"
def _as_json(obj):
    """Serialize a DTO's attribute dict as pretty-printed, key-sorted JSON."""
    return json.dumps(obj, default=lambda o: o.__dict__, sort_keys=True, indent=4)


class GatewayAction(metaclass=abc.ABCMeta):
    """Base payload for an action message dispatched to a PAM gateway."""

    def __init__(self, action, is_scheduled, gateway_destination=None, inputs=None, conversation_id=None, message_id=None):
        self.action = action
        self.is_scheduled = is_scheduled
        self.gateway_destination = gateway_destination
        self.inputs = inputs
        self.conversationId = conversation_id
        # messageId is derived from conversationId for WebRTC sessions
        self.messageId = message_id

    def toJSON(self):
        return _as_json(self)

    @staticmethod
    def generate_conversation_id(is_bytes=False):
        """Produce a random 16-byte conversation id, raw or base64-url encoded."""
        raw = crypto.get_random_bytes(16)
        return raw if is_bytes else utils.base64_url_encode(raw)

    @staticmethod
    def conversation_id_to_message_id(conversation_id):
        """Convert conversationId to messageId format (replace + with -, / with _)"""
        if not conversation_id:
            return None
        # Remove any padding '=' characters and replace special chars
        return conversation_id.rstrip('=').replace('+', '-').replace('/', '_')


class GatewayActionRotateInputs:
    """Inputs for the 'rotate' gateway action."""

    def __init__(self, record_uid, configuration_uid, pwd_complexity_encrypted, resource_uid=None):
        self.recordUid = record_uid
        self.configurationUid = configuration_uid
        self.pwdComplexity = pwd_complexity_encrypted
        self.resourceRef = resource_uid

    def toJSON(self):
        return _as_json(self)


class GatewayActionRotate(GatewayAction):
    """Scheduled 'rotate' action."""

    def __init__(self, inputs: GatewayActionRotateInputs, conversation_id=None, gateway_destination=None):
        super().__init__('rotate', inputs=inputs, conversation_id=conversation_id,
                         gateway_destination=gateway_destination, is_scheduled=True)

    def toJSON(self):
        return _as_json(self)


class GatewayActionJobInfoInputs:
    """Inputs for the 'job-info' gateway action."""

    def __init__(self, job_id):
        self.jobId = job_id

    def toJSON(self):
        return _as_json(self)


class GatewayActionJobInfo(GatewayAction):
    """Unscheduled 'job-info' query."""

    def __init__(self, inputs: GatewayActionJobInfoInputs, conversation_id=None):
        super().__init__('job-info', inputs=inputs, conversation_id=conversation_id, is_scheduled=False)

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverJobStartInputs:
    """Inputs for 'discover-job-start', including discovery filter settings."""

    def __init__(self, configuration_uid, user_map, shared_folder_uid, resource_uid=None, languages=None,
                 # Settings
                 include_machine_dir_users=False,
                 include_azure_aadds=False,
                 skip_rules=False,
                 skip_machines=False,
                 skip_databases=False,
                 skip_directories=False,
                 skip_cloud_users=False,
                 credentials=None):
        self.configurationUid = configuration_uid
        self.resourceUid = resource_uid
        self.userMap = user_map
        self.sharedFolderUid = shared_folder_uid
        # None sentinels avoid shared mutable defaults.
        self.languages = ["en_US"] if languages is None else languages
        self.includeMachineDirUsers = include_machine_dir_users
        self.includeAzureAadds = include_azure_aadds
        self.skipRules = skip_rules
        self.skipMachines = skip_machines
        self.skipDatabases = skip_databases
        self.skipDirectories = skip_directories
        self.skipCloudUsers = skip_cloud_users
        self.credentials = [] if credentials is None else credentials

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverJobStart(GatewayAction):
    """Scheduled 'discover-job-start' action."""

    def __init__(self, inputs: GatewayActionDiscoverJobStartInputs, conversation_id=None):
        super().__init__('discover-job-start', inputs=inputs, conversation_id=conversation_id, is_scheduled=True)

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverJobRemoveInputs:
    """Inputs for 'discover-job-remove'."""

    def __init__(self, configuration_uid, job_id):
        self.configurationUid = configuration_uid
        self.jobId = job_id

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverJobRemove(GatewayAction):
    """Scheduled 'discover-job-remove' action."""

    def __init__(self, inputs: GatewayActionDiscoverJobRemoveInputs, conversation_id=None):
        super().__init__('discover-job-remove', inputs=inputs, conversation_id=conversation_id, is_scheduled=True)

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverRuleValidateInputs:
    """Inputs for 'discover-rule-validate'."""

    def __init__(self, configuration_uid, statement):
        self.configurationUid = configuration_uid
        self.statement = statement

    def toJSON(self):
        return _as_json(self)


class GatewayActionDiscoverRuleValidate(GatewayAction):
    """Scheduled 'discover-rule-validate' action."""

    def __init__(self, inputs: GatewayActionDiscoverRuleValidateInputs, conversation_id=None):
        super().__init__('discover-rule-validate', inputs=inputs, conversation_id=conversation_id,
                         is_scheduled=True)

    def toJSON(self):
        return _as_json(self)
import argparse
import re
import time
from typing import Dict
from urllib.parse import urlparse, urlunparse

from keepersdk import crypto, utils
from keepersdk.proto import APIRequest_pb2
from keepersdk.vault import vault_record, vault_online

from ...commands import base
from ... import api
from ...params import KeeperParams
from ...helpers import router_utils, timeout_utils, email_utils, record_utils
from ...commands.pam.pam_dto import GatewayAction, GatewayActionRotate, GatewayActionRotateInputs, GatewayActionJobInfoInputs, GatewayActionJobInfo
from .discovery.discover import (PAMGatewayActionDiscoverJobStartCommand, PAMGatewayActionDiscoverJobStatusCommand,
                                 PAMGatewayActionDiscoverJobRemoveCommand, PAMGatewayActionDiscoverResultProcessCommand, PAMDiscoveryRuleCommand)
from .service.service_commands import PAMActionServiceListCommand, PAMActionServiceAddCommand, PAMActionServiceRemoveCommand
from .saas.saas_commands import PAMActionSaasConfigCommand, PAMActionSaasSetCommand, PAMActionSaasRemoveCommand, PAMActionSaasUserCommand, PAMActionSaasUpdateCommand
from .debug.debug_info import PAMDebugInfoCommand
from .debug.debug_gateway import PAMDebugGatewayCommand
from .debug.debug_graph import PAMDebugGraphCommand
from .debug.debug_acl import PAMDebugACLCommand
from .debug.debug_link import PAMDebugLinkCommand
from .debug.debug_rs import PAMDebugRotationSettingsCommand
from .debug.debug_vertex import PAMDebugVertexCommand

from keepersdk.proto import pam_pb2
from keepersdk.helpers import pam_config_facade, config_utils
from keepersdk.helpers.tunnel import tunnel_graph, tunnel_utils
from keepersdk.helpers.keeper_dag import record_link as record_link_utils, dag_utils

logger = api.get_logger()


class PAMGatewayActionServerInfoCommand(base.ArgparseCommand):
    """Query a gateway for its server information via the 'server_info' action."""

    def __init__(self):
        parser = argparse.ArgumentParser(prog='dr-info-command')
        PAMGatewayActionServerInfoCommand.add_arguments_to_parser(parser)
        super().__init__(parser)

    @staticmethod
    def add_arguments_to_parser(parser: argparse.ArgumentParser):
        parser.add_argument('--gateway', '-g', required=False, dest='gateway_uid', action='store', help='Gateway UID')
        parser.add_argument('--verbose', '-v', required=False, dest='verbose', action='store_true', help='Verbose Output')

    def execute(self, context: KeeperParams, **kwargs):
        """Send the non-streaming 'server_info' action and print the gateway's reply."""
        gateway_uid = kwargs.get('gateway_uid')
        verbose = kwargs.get('verbose')
        response = router_utils.router_send_action_to_gateway(
            context=context,
            gateway_action=GatewayAction(action='server_info', is_scheduled=False),
            message_type=pam_pb2.CMT_GENERAL,
            is_streaming=False,
            destination_gateway_uid_str=gateway_uid
        )
        router_utils.print_router_response(response, 'gateway_info', is_verbose=verbose, gateway_uid=gateway_uid)


class PAMGatewayActionRotateCommand(base.ArgparseCommand):

    def __init__(self):
        parser = argparse.ArgumentParser(prog='pam action rotate')
        PAMGatewayActionRotateCommand.add_arguments_to_parser(parser)
        super().__init__(parser)

    @staticmethod
    def add_arguments_to_parser(parser: argparse.ArgumentParser):
        """Register rotation options, including one-time-share / email delivery flags."""
        parser.add_argument('--record-uid', '-r', dest='record_uid', action='store', help='Record UID to rotate')
        parser.add_argument('--folder', '-f', dest='folder', action='store', help='Shared folder UID or title pattern to rotate')
        parser.add_argument('--dry-run', '-n', dest='dry_run', default=False, action='store_true', help='Enable dry-run mode')
        parser.add_argument('--self-destruct', dest='self_destruct', action='store',
                            metavar='[(m)inutes|(h)ours|(d)ays]',
                            help='Create one-time share link that expires after duration')
        parser.add_argument('--email-config', dest='email_config', action='store',
                            help='Email configuration name to use for sending (required with --send-email)')
        parser.add_argument('--send-email', dest='send_email', action='store',
                            help='Email address to send credentials after rotation')
        parser.add_argument('--email-message', dest='email_message', action='store',
                            help='Custom message to include in email')
Login to initialize the vault.') + + vault = context.vault + + record_uid = kwargs.get('record_uid', '') + folder = kwargs.get('folder', '') + recursive = kwargs.get('recursive', False) + pattern = kwargs.get('pattern', '') # additional record title match pattern + dry_run = kwargs.get('dry_run', False) + + # Store email/share arguments as instance variables + self.self_destruct = kwargs.get('self_destruct') + self.email_config = kwargs.get('email_config') + self.send_email = kwargs.get('send_email') + self.email_message = kwargs.get('email_message') + + vault = context.vault + + # Validate email setup early (before rotation) to avoid rotating password without being able to send email + if self.send_email: + if not self.email_config: + raise base.CommandError('--send-email requires --email-config to specify email configuration') + + # Find and load email config to validate provider and dependencies + try: + config_uid = email_utils.find_email_config_record(vault, self.email_config) + email_config_obj = email_utils.load_email_config_from_record(vault, config_uid) + + # Check if required dependencies are installed for this provider + is_valid, error_message = email_utils.validate_email_provider_dependencies(email_config_obj.provider) + + if not is_valid: + raise base.CommandError(f'\n{error_message}') + + except Exception as e: + # Re-raise CommandError as-is, wrap other exceptions + if isinstance(e, base.CommandError): + raise + raise base.CommandError(f'Failed to validate email configuration: {e}') + + # record, folder or pattern - at least one required + if not record_uid and not folder: + logger.info(f'the following arguments are required: --record-uid/-r or --folder/-f') + return + + # single record UID - ignore all folder options + if not folder: + self.record_rotate(context, record_uid) + return + + # folder UID or pattern (ignore --record-uid/-r option) + folders = [] # root folders matching UID or title pattern + records = [] # record UIDs of all v3/pamUser 
records + + # 1. find all shared_folder/shared_folder_folder matching --folder=UID/pattern + if folder in vault.vault_data.folders(): # folder UID + fldr = vault.vault_data.get_folder(folder) + # only shared_folder can be shared to KSM App/Gateway for rotation + # but its children shared_folder_folder can contain rotation records too + if fldr.folder_type in ('shared_folder', 'shared_folder_folder'): + folders.append(folder) + else: + logger.debug(f'Folder skipped (not a shared folder/subfolder) - {folder} {fldr.name}') + else: + rx_name = self.str_to_regex(folder) + for fuid in vault.vault_data.folders(): + fldr = vault.vault_data.get_folder(fuid) + # requirement - shared folder only (not for user_folder containing shf w/ recursion) + if fldr.folder_type in ('shared_folder', 'shared_folder_folder'): + if fldr.name and rx_name.search(fldr.name): + folders.append(fldr.uid) + + folders = list(set(folders)) # Remove duplicate UIDs + # 2. pattern could match both parent and child - drop all children (w/ a matching parent) + if recursive and len(folders) > 1: + roots: Dict[str, list] = {} # group by shared_folder_uid + for fuid in folders: # no shf inside shf yet + roots.setdefault(vault.vault_data.get_folder(fuid).folder_scope_uid, []).append(fuid) + uniq = [] + for fuid in roots: + fldrs = list(set(roots[fuid])) + if len(fldrs) == 1: # no siblings + uniq.append(fldrs[0]) + elif fuid in fldrs: # parent shf is topmost + uniq.append(fuid) + else: # topmost sibling(s) + fldrset = set(fldrs) + for fldr in fldrs: + path = [] + child = fldr + while vault.vault_data.get_folder(child).folder_uid != fuid: + path.append(child) + child = vault.vault_data.get_folder(child).parent_uid + path.append(child) # add root shf + path = path[1:] if path else [] # skip child uid + if not set(path) & fldrset: # no intersect + uniq.append(fldr) + folders = list(set(uniq)) + + # 3. 
collect all recs pamUsers w/ rotation set-up --recursive or not + for fldr in folders: + if recursive: + logger.warning('--recursive/-a option not implemented (ignored)') + + folder_sub = vault.vault_data.get_folder(fldr) + folder_records = folder_sub.records + + if folder_records: + logger.debug(f"folder {fldr} empty - no records in folder(skipped)") + continue + for ruid in folder_records: + record = vault.vault_data.get_record(ruid) + if record and record.record_type == 'pamUser': + records.append(ruid) + records = list(set(records)) # Remove duplicate UIDs + + # 4. print number of folders and records to rotate - folders: 2+0/16, records 50,000 + logger.info(f'Selected for rotation - folders: {len(folders)}, records: {len(records)}, recursive={recursive}') + + # 5. in debug - print actual folders and records selected for rotation + if logger.isEnabledFor(logger.DEBUG): + for fldr in folders: + fobj = vault.vault_data.get_folder(fldr) + title = fobj.name if fobj else '' + logger.debug(f'Rotation Folder UID: {fldr} {title}') + for rec in records: + record = vault.vault_data.get_record(rec) + title = record.title if record else '' + logger.debug(f'Rotation Record UID: {rec} {title}') + + # 6. exit if --dry-run + if dry_run: + return + + # 7. rotate and handle any throttles (to work with 50,000 records) + for record_uid in records: + delay = 0 + while True: + try: + # Handle throttles in-loop on in-record_rotate + self.record_rotate(context, record_uid, True) + break + except Exception as e: + msg = str(e) # what is considered a throttling error... 
+ if re.search(r"throttle", msg, re.IGNORECASE): + delay = (delay+10) % 100 # reset every 1.5 minutes + logger.debug(f'Record UID: {record_uid} was throttled (retry in {delay} sec)') + time.sleep(1+delay) + else: + logger.error(f'Record UID: {record_uid} skipped: non-throttling, non-recoverable error: {msg}') + break + + def record_rotate(self, context: KeeperParams, record_uid, slient:bool = False): + vault = context.vault + record = vault.vault_data.load_record(record_uid) + if not isinstance(record, vault_record.TypedRecord): + logger.error(f'Record [{record_uid}] is not available.') + return + + # Find record by record uid + ri = record_utils.record_rotation_get(vault, utils.base64_url_decode(record.record_uid)) + ri_pwd_complexity_encrypted = ri.pwdComplexity + if not ri_pwd_complexity_encrypted: + rule_list_dict = { + 'length': 20, + 'caps': 1, + 'lowercase': 1, + 'digits': 1, + 'special': 1, + } + ri_pwd_complexity_encrypted = utils.base64_url_encode(router_utils.encrypt_pwd_complexity(rule_list_dict, record.record_key)) + + resource_uid = None + + encrypted_session_token, encrypted_transmission_key, transmission_key = tunnel_utils.get_keeper_tokens(vault) + config_uid = tunnel_utils.get_config_uid(vault, encrypted_session_token, encrypted_transmission_key, record_uid) + if not config_uid: + # Still try it the old way + # Configuration on the UI is "Rotation Setting" + ri_rotation_setting_uid = utils.base64_url_encode(ri.configurationUid) + resource_uid = utils.base64_url_encode(ri.resourceUid) + pam_config = vault.vault_data.load_record(ri_rotation_setting_uid) + if not isinstance(pam_config, vault_record.TypedRecord): + logger.error(f'PAM Configuration [{ri_rotation_setting_uid}] is not available.') + return + facade = pam_config_facade.PamConfigurationRecordFacade() + facade.record = pam_config + + config_uid = facade.controller_uid + + if not resource_uid: + tmp_dag = tunnel_graph.TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, 
    def record_rotate(self, context: KeeperParams, record_uid, slient:bool = False):
        """Rotate a single record's credentials through its PAM gateway.

        Resolves the record's PAM configuration (graph first, then the legacy
        rotation-settings record), resolves the target resource (or NOOP),
        checks the gateway is connected, and dispatches a 'rotate' action.
        NOTE(review): the parameter is spelled `slient` (sic); kept because
        callers pass it positionally/by name.
        """
        vault = context.vault
        record = vault.vault_data.load_record(record_uid)
        if not isinstance(record, vault_record.TypedRecord):
            logger.error(f'Record [{record_uid}] is not available.')
            return

        # Find record by record uid
        ri = record_utils.record_rotation_get(vault, utils.base64_url_decode(record.record_uid))
        ri_pwd_complexity_encrypted = ri.pwdComplexity
        # No stored complexity: fall back to a default 20-char policy encrypted with the record key.
        if not ri_pwd_complexity_encrypted:
            rule_list_dict = {
                'length': 20,
                'caps': 1,
                'lowercase': 1,
                'digits': 1,
                'special': 1,
            }
            ri_pwd_complexity_encrypted = utils.base64_url_encode(router_utils.encrypt_pwd_complexity(rule_list_dict, record.record_key))

        resource_uid = None

        encrypted_session_token, encrypted_transmission_key, transmission_key = tunnel_utils.get_keeper_tokens(vault)
        config_uid = tunnel_utils.get_config_uid(vault, encrypted_session_token, encrypted_transmission_key, record_uid)
        if not config_uid:
            # Still try it the old way
            # Configuration on the UI is "Rotation Setting"
            ri_rotation_setting_uid = utils.base64_url_encode(ri.configurationUid)
            resource_uid = utils.base64_url_encode(ri.resourceUid)
            pam_config = vault.vault_data.load_record(ri_rotation_setting_uid)
            if not isinstance(pam_config, vault_record.TypedRecord):
                logger.error(f'PAM Configuration [{ri_rotation_setting_uid}] is not available.')
                return
            facade = pam_config_facade.PamConfigurationRecordFacade()
            facade.record = pam_config

            config_uid = facade.controller_uid

        if not resource_uid:
            # Ask the tunnel DAG which resource this record rotates against.
            tmp_dag = tunnel_graph.TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, record.record_uid,
                                            transmission_key=transmission_key)
            resource_uid = tmp_dag.get_resource_uid(record_uid)
            if not resource_uid:
                # NOOP records don't need resource_uid
                is_noop = False
                pam_config = vault.vault_data.load_record(config_uid)

                # Check the graph for the noop setting.
                record_link = record_link_utils.RecordLink(record=pam_config,
                                                           context=context,
                                                           fail_on_corrupt=False)
                acl = record_link.get_acl(record_uid, pam_config.record_uid)
                if acl is not None and acl.rotation_settings is not None:
                    is_noop = acl.rotation_settings.noop

                # If it was false in the graph, or did not exist, check the record.
                if is_noop is False:
                    noop_field = record.get_typed_field('text', 'NOOP')
                    is_noop = dag_utils.value_to_boolean(noop_field.value[0]) if noop_field and noop_field.value else False

                if not is_noop:
                    logger.error(f'Resource UID not found for record [{record_uid}]. please configure it '
                                 f'"pam rotation user {record_uid} --resource RESOURCE_UID"')
                    return

        controller = config_utils.configuration_controller_get(vault, utils.base64_url_decode(config_uid))
        if not controller.controllerUid:
            raise base.CommandError(f'Gateway UID not found for configuration '
                                    f'{config_uid}.')

        # Find connected controllers
        enterprise_controllers_connected = router_utils.router_get_connected_gateways(vault)

        controller_from_config_bytes = controller.controllerUid
        gateway_uid = utils.base64_url_encode(controller.controllerUid)
        if enterprise_controllers_connected:
            # Map connected controllers by their raw UID bytes for lookup.
            router_controllers = {controller.controllerUid: controller for controller in
                                  list(enterprise_controllers_connected.controllers)}
            connected_controller = router_controllers.get(controller_from_config_bytes)

            if not connected_controller:
                logger.warning(f'The Gateway "{gateway_uid}" is down.')
                return
        else:
            logger.warning(f'There are no connected gateways.')
            return

        action_inputs = GatewayActionRotateInputs(
            record_uid=record_uid,
            configuration_uid=config_uid,
            pwd_complexity_encrypted=ri_pwd_complexity_encrypted,
            resource_uid=resource_uid
        )

        conversation_id = GatewayAction.generate_conversation_id()

        router_response = router_utils.router_send_action_to_gateway(
            context=context, gateway_action=GatewayActionRotate(inputs=action_inputs, conversation_id=conversation_id,
                                                                gateway_destination=gateway_uid),
            message_type=pam_pb2.CMT_ROTATE, is_streaming=False,
            transmission_key=transmission_key,
            encrypted_transmission_key=encrypted_transmission_key,
            encrypted_session_token=encrypted_session_token)

        # Handle post-rotation email/share if requested
        if (self.self_destruct or self.send_email) and router_response:
            try:
                # Sync params to get updated record with rotated password
                vault.sync_down(force=True)
                # Reload record to get latest credentials
                record = vault.vault_data.load_record(record_uid)
                if isinstance(record, vault_record.TypedRecord):
                    self._handle_post_rotation_email(vault, record)
            except Exception as e:
                logger.warning(f'Post-rotation email handling failed: {e}')
                # Don't fail the rotation if email fails

        if not slient:
            router_utils.print_router_response(router_response, 'job_info', conversation_id, gateway_uid=gateway_uid)
Parse timeout and create share link + share_url = None + expiration_text = None + if self.self_destruct: + try: + # parse_timeout returns a timedelta object + expiration_period = timeout_utils.parse_timeout(self.self_destruct) + expire_seconds = int(expiration_period.total_seconds()) + + if expire_seconds <= 0: + logger.warning(f'Invalid --self-destruct value. Skipping share link.') + return + + # Calculate human-readable expiration text + if expire_seconds >= 86400: # days + days = expire_seconds // 86400 + expiration_text = f"{days} day{'s' if days > 1 else ''}" + elif expire_seconds >= 3600: # hours + hours = expire_seconds // 3600 + expiration_text = f"{hours} hour{'s' if hours > 1 else ''}" + else: # minutes + minutes = expire_seconds // 60 + expiration_text = f"{minutes} minute{'s' if minutes > 1 else ''}" + + # 3. Create one-time share link manually (same as record_edit.py) + logger.info(f'Creating one-time share link expiring in {self.self_destruct}...') + record_uid = record.record_uid + record_key = record.record_key + client_key = utils.generate_aes_key() + client_id = crypto.hmac_sha512(client_key, 'KEEPER_SECRETS_MANAGER_CLIENT_ID'.encode()) + rq = APIRequest_pb2.AddExternalShareRequest() + rq.recordUid = utils.base64_url_decode(record_uid) + rq.encryptedRecordKey = crypto.encrypt_aes_v2(record_key, client_key) + rq.clientId = client_id + rq.accessExpireOn = utils.current_milli_time() + int(expiration_period.total_seconds() * 1000) + rq.isSelfDestruct = user_requested_self_destruct + vault.keeper_auth.execute_auth_rest( + rest_endpoint='vault/external_share_add', + request=rq, + response_type=APIRequest_pb2.Device + ) + # Extract hostname from context.auth.keeper_endpoint.server + parsed = urlparse(vault.keeper_auth.keeper_endpoint.server) + server_netloc = parsed.netloc if parsed.netloc else parsed.path + share_url = urlunparse(('https', server_netloc, '/vault/share', None, None, utils.base64_url_encode(client_key))) + logger.info(f'Share link created 
successfully') + except Exception as e: + logger.warning(f'Failed to create share link: {e}') + return + + # 4. Send email if requested + if self.send_email and self.email_config and share_url: + try: + # Find email configuration record by name + logger.info(f'Loading email configuration: {self.email_config}') + config_uid = email_utils.find_email_config_record(vault, self.email_config) + if not config_uid: + logger.warning(f'Email configuration "{self.email_config}" not found. Skipping email.') + return + + # Load the email configuration + email_config = email_utils.load_email_config_from_record(vault, config_uid) + + # 5. Build email HTML content with share link + custom_message = self.email_message or 'Your password has been rotated. Click the link below to view your new credentials.' + + html_content = email_utils.build_onboarding_email( + share_url=share_url, + custom_message=custom_message, + record_title=record.title, + expiration=expiration_text + ) + + # 6. Send email + logger.info(f'Sending email to {self.send_email}...') + email_sender = email_utils.EmailSender(email_config) + email_sender.send( + to=self.send_email, + subject=f"Password Rotated: {record.title}", + body=html_content, + html=True + ) + + # 7. 
Persist OAuth tokens if refreshed + if email_config.is_oauth_provider() and email_config._oauth_tokens_updated: + logger.info('Updating OAuth tokens in email configuration record...') + email_utils.update_oauth_tokens_in_record( + vault, + config_uid, + email_config.oauth_access_token, + email_config.oauth_refresh_token, + email_config.oauth_token_expiry + ) + + logger.info(f'Email sent successfully to {self.send_email}') + + except Exception as e: + logger.warning(f'Failed to send email: {e}') + # Don't fail the rotation if email fails + return + + except Exception as e: + logger.warning(f'Error in post-rotation email handler: {e}') + # Don't fail the rotation if email fails + + def str_to_regex(self, text): + text = str(text) + try: + pattern = re.compile(text, re.IGNORECASE) + except: # re.error: yet maybe TypeError, MemoryError, RecursionError etc. + pattern = re.compile(re.escape(text), re.IGNORECASE) + logger.debug(f"regex pattern {text} failed to compile (using it as plaintext pattern)") + return pattern + +class PAMGatewayActionJobCommand(base.ArgparseCommand): + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action job') + PAMGatewayActionJobCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=False, dest='gateway_uid', action='store', + help='Gateway UID. 
Needed only if there are more than one gateway running') + parser.add_argument('job_id', help='Job ID') + + def execute(self, context: KeeperParams, **kwargs): + job_id = kwargs.get('job_id') + gateway_uid = kwargs.get('gateway_uid') + + logger.info(f"Job id to check [{job_id}]") + + action_inputs = GatewayActionJobInfoInputs(job_id) + + conversation_id = GatewayAction.generate_conversation_id() + router_response = router_utils.router_send_action_to_gateway( + context=context, + gateway_action=GatewayActionJobInfo(inputs=action_inputs, conversation_id=conversation_id), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid, + ) + + router_utils.print_router_response(router_response, 'job_info', original_conversation_id=conversation_id, gateway_uid=gateway_uid) + + +class PAMDiscoveryCommand(base.GroupCommand): + def __init__(self): + super().__init__('PAM Discovery') + self.register_command(PAMGatewayActionDiscoverJobStartCommand(), 'start', 's') + self.register_command(PAMGatewayActionDiscoverJobStatusCommand(), 'status', 'st') + self.register_command(PAMGatewayActionDiscoverJobRemoveCommand(), 'remove', 'r') + self.register_command(PAMGatewayActionDiscoverResultProcessCommand(), 'process', 'p') + self.register_command(PAMDiscoveryRuleCommand(), 'rule', 'r') + + self.default_verb = 'status' + + +class PAMActionServiceCommand(base.GroupCommand): + + def __init__(self): + super().__init__('PAM Action Service') + self.register_command(PAMActionServiceListCommand(), 'list', 'l') + self.register_command(PAMActionServiceAddCommand(), 'add', 'a') + self.register_command(PAMActionServiceRemoveCommand(), 'remove', 'r') + self.default_verb = 'list' + + +class PAMActionSaasCommand(base.GroupCommand): + + def __init__(self): + super().__init__('PAM Action Saas') + self.register_command(PAMActionSaasConfigCommand(), 'config', 'c') + self.register_command(PAMActionSaasSetCommand(), 'set', 's') + 
self.register_command(PAMActionSaasRemoveCommand(), 'remove', 'r') + self.register_command(PAMActionSaasUserCommand(), 'user', 'i') + self.register_command(PAMActionSaasUpdateCommand(), 'update', 'u') + + +class PAMDebugCommand(base.GroupCommand): + + def __init__(self): + super().__init__('PAM Debug') + self.register_command(PAMDebugInfoCommand(), 'info', 'i') + self.register_command(PAMDebugGatewayCommand(), 'gateway', 'g') + self.register_command(PAMDebugGraphCommand(), 'graph', 'r') + self.register_command(PAMDebugACLCommand(), 'acl', 'c') + self.register_command(PAMDebugLinkCommand(), 'link', 'l') + self.register_command(PAMDebugRotationSettingsCommand(), 'rs-reset', 'rs') + self.register_command(PAMDebugVertexCommand(), 'vertex', 'v') + + diff --git a/keepercli-package/src/keepercli/commands/pam/saas/__init__.py b/keepercli-package/src/keepercli/commands/pam/saas/__init__.py new file mode 100644 index 00000000..52d77c6e --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/saas/__init__.py @@ -0,0 +1,366 @@ +from __future__ import annotations +import json +from ....helpers.router_utils import router_send_action_to_gateway, get_response_payload +from ...pam.pam_dto import GatewayAction +from keepersdk.proto import pam_pb2 +from keepersdk.vault import vault_record +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk import utils +import hmac +import hashlib +import os +from pydantic import BaseModel +from typing import Optional, List, Any, TYPE_CHECKING + +from .... 
import api + +logger = api.get_logger() +if TYPE_CHECKING: + from ..discovery.discover import GatewayContext + from ....params import KeeperParams + from keepersdk.vault import vault_record + from keepersdk.helpers.keeper_dag.dag_vertex import DAGVertex + + +CATALOG_REPO = "Keeper-Security/discovery-and-rotation-saas-dev" + + +class GatewayActionSaasListCommandInputs: + + def __init__(self, + configuration_uid: str, + languages: Optional[List[str]] = None): + + if languages is None: + languages = ["en_US"] + + self.configurationUid = configuration_uid + self.languages = languages + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionSaasListCommand(GatewayAction): + + def __init__(self, inputs: GatewayActionSaasListCommandInputs, conversation_id=None): + super().__init__('saas-list', inputs=inputs, conversation_id=conversation_id, is_scheduled=True) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +# These are from KDNRM saas_type.py +class SaasConfigEnum(BaseModel): + value: str + desc: Optional[str] = None + code: Optional[str] = None + + +class SaasConfigItem(BaseModel): + id: str + label: str + desc: str + is_secret: bool = False + desc_code: Optional[str] = None + type: Optional[str] = "text" + code: Optional[str] = None + default_value: Optional[Any] = None + enum_values: List[SaasConfigEnum] = [] + required: bool = False + + +class SaasPluginUsage(BaseModel): + record_id: str + plugin_name: str + user_uids: List[str] = [] + + +class SaasCatalog(BaseModel): + name: str + type: str = "catalog" + author: Optional[str] = None + email: Optional[str] = None + summary: Optional[str] = None + file: Optional[str] = None + file_sig: Optional[str] = None + allows_remote_management: Optional[bool] = False + readme: Optional[str] = None + fields: List[SaasConfigItem] = [] + installed: bool = False + used_by: List[SaasPluginUsage] = [] + + 
@property + def file_name(self): + # `file` will be either a file name or URL to a file. + # This will just get the file name. + return self.file.split("/")[-1] if self.file else None + + +def get_gateway_saas_schema(context: KeeperParams, gateway_context: GatewayContext) -> Optional[List[dict]]: + + """ + Get a plugins list from the Gateway. + + Using the plugins from the Gateway handles problem with versions. + We can work off the builtin, and custom, plugins available to the version of the Gateway. + """ + + if gateway_context is None: + logger.error(f"The user record does not have the set gateway") + return None + + # Get schema information from the Gateway + action_inputs = GatewayActionSaasListCommandInputs( + configuration_uid=gateway_context.configuration_uid, + ) + + conversation_id = GatewayAction.generate_conversation_id() + router_response = router_send_action_to_gateway( + context=context, + gateway_action=GatewayActionSaasListCommand( + inputs=action_inputs, + conversation_id=conversation_id), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_context.gateway_uid + ) + + if router_response is None: + logger.error(f"Did not get router response.") + return None + + response = router_response.get("response") + logger.debug(f"Router Response: {response}") + payload = get_response_payload(router_response) + data = payload.get("data") + if data is None: + raise Exception("The router returned a failure.") + elif data.get("success") is False: + error = data.get("error") + logger.debug(f"gateway returned: {error}") + logger.error(f"Could not get a list of SaaS plugins available on the gateway.") + return None + + return data.get("data", []) + + +def find_user_saas_configurations(context: KeeperParams, gateway_context: GatewayContext) -> dict: + + """ + Find all the SaaS configuration being uses by a gateway. 
+ + + + """ + + record_linking = RecordLink(record=gateway_context.configuration, context=context, logger=logger) + + def _walk_graph(v: DAGVertex, m: dict, pv: Optional[DAGVertex] = None): + + # Skip any disabled vertices + if not v.active: + return + + # Get the record for this vertex; skip is not a pamUser. + record = context.vault.vault_data.load_record(v.uid) + if record is not None and record.record_type == "pamUser" and pv is not None: + acl = record_linking.get_acl(v.uid, pv.uid) + if acl is not None and acl.rotation_settings is not None: + config_record_uid_list = acl.rotation_settings.saas_record_uid_list + if config_record_uid_list is not None: + for config_record_uid in config_record_uid_list: + + if config_record_uid not in m: + config_record = context.vault.vault_data.load_record( + config_record_uid) + if config_record is None: + continue + + plugin_name = next((f.value for f in config_record.custom if f.label == "SaaS Type"), + None) + if plugin_name is None: + continue + + m[config_record_uid] = SaasPluginUsage( + record_id=config_record_uid, + plugin_name=plugin_name[0] + ) + + m[config_record_uid].user_uids.append(v.uid) + + for next_v in v.has_vertices(): + _walk_graph( + v=next_v, + pv=v, + m=m + ) + + usage_map = {} + _walk_graph( + v=record_linking.dag.get_root, + m=usage_map) + + return usage_map + + +def get_plugins_map(context: KeeperParams, gateway_context: GatewayContext) -> Optional[dict[str, SaasCatalog]]: + + """ + Get a map of all the available plugins. + + This will first get the latest catalog from the GitHub repo. + The catalog will contain the plugin available from the repo and built in. + + Then the Gateway is checked for custom plugin; plugins outside our control. + + The result is a dictionary, with the plugin name as the key. 
+ + """ + + plugin_map = {} + + # #### GATEWAY PLUGINS + + # Get a list of installed plugins (custom and builtin) from the Gateway + gateway_plugins = get_gateway_saas_schema(context, gateway_context) + if gateway_plugins is None: + return None + + # Add the Gateway plugins to map; all these plugins are installed. + for plugin_dict in gateway_plugins: + plugin = SaasCatalog.model_validate(plugin_dict) # type: SaasCatalog + plugin.installed = True + plugin_map[plugin.name] = plugin + + # #### CATALOG PLUGINS + + # Get the latest release of the catalog.json + api_url = f"https://api.github.com/repos/{CATALOG_REPO}/releases/latest" + res = utils.ssl_aware_get(api_url) + if res.ok is False: + logger.info("") + logger.error(f"Could not get plugin catalog from GitHub.") + return None + release_data = res.json() + + # Find the latest release URL + assets = release_data.get("assets", []) + asset = assets[0] + download_url = asset["browser_download_url"] + logger.debug(f"download {asset['name']} from {download_url}") + + # Download the latest the catalog.yml + res = utils.ssl_aware_get(download_url) + if res.ok is False: + logger.info("") + logger.error(f"Could not download the plugin catalog from GitHub.") + return None + + # Get a mapping of all the plugins being used by the plugin name. 
+ # The group usage by plugin name; we can have multiple configuration for the same plugin + # This return dictionary of config record UID to SaasPluginUsage + plugin_usage = find_user_saas_configurations(context, gateway_context) + plugin_usage_map = {} + for config_record_uid in plugin_usage: # type: str + plugin_name = plugin_usage[config_record_uid].plugin_name + if plugin_name not in plugin_usage_map: + plugin_usage_map[plugin_name] = [] + plugin_usage_map[plugin_name].append(plugin_usage[config_record_uid]) + + for plugin_dict in json.loads(res.content): # type: dict + if plugin_dict.get("type") == "builtin": + continue + + plugin = SaasCatalog.model_validate(plugin_dict) # type: SaasCatalog + if plugin.name in plugin_map: + logger.debug(f"found duplicate plugin {plugin.name}; using plugin from gateway.") + continue + plugin_map[plugin.name] = plugin + + return plugin_map + + +def make_script_signature(plugin_code_bytes: bytes) -> str: + + # To use HMAC, we need to have a key; the key is not a secret, we just want to make a unique digest. 
+ this_is_not_a_secret = b"NOT_IMPORTANT" + return hmac.new(this_is_not_a_secret, plugin_code_bytes, hashlib.sha256).hexdigest() + + +def get_field_input(field, current_value: Optional[str] = None): + + logger.debug(field.model_dump_json()) + + logger.info(f"{field.label}") + logger.info(f"Description: {field.desc}") + if field.required is True: + logger.warning(f"Field is required.") + if field.type == "multiline": + logger.info(f"Enter a file path to load value from file.") + + while True: + prompt = "Enter value" + extra_text = [] + valid_values = [] + if len(field.enum_values) > 0: + valid_values = [str(x.value) for x in field.enum_values] + extra_text.append(f"Allowed values: " + + f", ".join(valid_values)) + if current_value is not None: + extra_text.append(f"Enter for current value '{current_value}'") + if field.default_value is not None: + extra_text.append(f"Enter for default value '{field.default_value}'") + if len(extra_text) > 0: + prompt += f" (" + "; ".join(extra_text) + ")" + prompt += " > " + value = input(prompt) + if value == "": + if current_value is not None: + value = current_value + elif field.default_value is not None: + value = field.default_value + elif os.path.exists(value): + with open(value, "r") as fh: + value = fh.read() + fh.close() + if len(valid_values) > 0 and value not in valid_values: + logger.error(f"{value} is not a valid value.") + continue + if value is not None: + break + if field.required is False: + break + + logger.error(f"This field is required.") + + return [value] + + +def get_record_field_value(record: vault_record.TypedRecord, label: str) -> Optional[str]: + + field = next((f for f in record.custom if f.label == label), None) + if field is None or field.value is None or len(field.value) == 0 or field.value[0] is None or field.value[0] == "": + return None + return field.value[0] + + +def set_record_field_value(record: vault_record.TypedRecord, label: str, value: str, field_type: Optional[str] = "text"): + + if value 
is not None and isinstance(value, list): + value = value[0] + + field = next((f for f in record.custom if f.label == label), None) + if field is None or field.value is None or len(field.value) == 0 or field.value[0] is None or field.value[0] == "": + if value is not None: + record.custom.append( + vault_record.TypedField.new_field( + field_label=label, + field_type=field_type, + field_value=[value] + ) + ) + elif value is not None: + field.value = [value] + else: + field.value = [] diff --git a/keepercli-package/src/keepercli/commands/pam/saas/saas_commands.py b/keepercli-package/src/keepercli/commands/pam/saas/saas_commands.py new file mode 100644 index 00000000..02d2d98f --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/saas/saas_commands.py @@ -0,0 +1,1115 @@ + +import argparse +import json +import logging +import os +import traceback +from typing import List, Optional +from tempfile import TemporaryDirectory + +from keepersdk.helpers.keeper_dag import dag_utils + +from .... import api +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from ..pam_dto import GatewayAction +from ....params import KeeperParams +from .... import api +from ....__init__ import __version__ +from . 
import ( + SaasCatalog, + get_plugins_map, + get_field_input, + make_script_signature, + get_record_field_value, + set_record_field_value, +) + +from keepersdk.vault import vault_record, vault_extensions, attachment, record_management +from keepersdk.helpers.keeper_dag.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from keepersdk.helpers.keeper_dag.record_link import RecordLink +from keepersdk.helpers.keeper_dag.dag_types import UserAclRotationSettings +from keepersdk import crypto, utils +from keepersdk.proto import record_pb2 +from keepersdk.errors import KeeperApiError + +logger = api.get_logger() + + +class RecordNotConfigException(Exception): + pass + + +class GatewayActionSaasConfigCommandInputs: + + def __init__(self, + configuration_uid: str, + plugin_code: str, + gateway_context: GatewayContext, + languages: Optional[List[str]] = None, + ): + + if languages is None: + languages = ["en_US"] + + self.configurationUid = configuration_uid + self.pluginCodeEnv = gateway_context.encrypt_str(plugin_code) + self.languages = languages + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionSaasListCommand(GatewayAction): + + def __init__(self, inputs: GatewayActionSaasConfigCommandInputs, conversation_id=None): + super().__init__('saas-list', inputs=inputs, conversation_id=conversation_id, is_scheduled=True) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class PAMActionSaasConfigCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-saas-config') + PAMActionSaasConfigCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + 
parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--list', '-l', required=False, dest='do_list', action='store_true', + help='List available SaaS rotations.') + parser.add_argument('--plugin', '-p', required=False, dest='plugin', action='store', + help='Plugin name') + parser.add_argument('--info', required=False, dest='do_info', action='store_true', + help='Get information about a plugin or plugins being used.') + parser.add_argument('--create', required=False, dest='do_create', action='store_true', + help='Create a SaaS Plugin config record.') + parser.add_argument('--update-config-uid', '-u', required=False, dest='do_update', action='store', + help='Update an existing SaaS configuration.') + parser.add_argument('--shared-folder-uid', '-s', required=False, dest='shared_folder_uid', + action='store', help='Shared folder to store SaaS configuration.') + + @staticmethod + def _show_list(plugins: dict[str, SaasCatalog]): + + sorted_catalog = {} + if plugins: + sorted_catalog = dict(sorted(plugins.items(), key=lambda i: i[1].name)) + + sort_results = { + "custom": {"title": "Custom", "using": [], "not_using": []}, + "catalog": {"title": "Catalog", "using": [], "not_using": []}, + "builtin": {"title": "Builtin", "using": [], "not_using": []}, + } + + logger.info("") + logger.info(f"Available SaaS Plugins") + for _, plugin in sorted_catalog.items(): + plugin_type = plugin.type + status = "using" if len(plugin.used_by) is True else "not_using" + sort_results[plugin_type][status].append(plugin) + + for plugin_type in ["custom", "catalog", "builtin"]: + for status in ["not_using", "using"]: + title = sort_results[plugin_type]["title"] + for plugin in sort_results[plugin_type][status]: + summary = plugin.summary or "No description available" + name = plugin.name + desc = f" ({title}" + if status == "using": + desc += f", Using" + desc += 
f")" + row = f" * {name}{desc} - {summary}" + logger.info(row) + + @staticmethod + def _show_plugin_info(plugin: SaasCatalog): + logger.info("") + logger.info(f"{plugin.name}") + logger.info(f" Type: {plugin.type}") + if plugin.author and plugin.email: + logger.info(f" Author: {plugin.author} ({plugin.email})") + elif plugin.author: + logger.info(f" Author: {plugin.author}") + logger.info(f" Summary: {plugin.summary or 'No description available'}") + if plugin.readme: + logger.info(f" Documents: {plugin.readme}") + logger.info(f" Fields") + req_field = [] + opt_field = [] + for field in plugin.fields: + if field.required: + req_field.append(f" * Required: {field.label} - " + f"{field.desc}") + else: + opt_field.append(f" * Optional: {field.label} - {field.desc}") + for item in req_field: + logger.info(item) + for item in opt_field: + logger.info(item) + logger.info("") + + @staticmethod + def _create_config(context: KeeperParams, + plugin: SaasCatalog, + shared_folder_uid: str, + plugin_code_bytes: Optional[bytes] = None): + + custom_fields = [ + vault_record.TypedField.new_field( + field_type="text", + field_label="SaaS Type", + field_value=[plugin.name] + ), + vault_record.TypedField.new_field( + field_type="text", + field_label="Active", + field_value=["TRUE"] + ) + ] + + for is_required in [True, False]: + for item in plugin.fields: + if item.required is is_required: + logger.info("") + value = get_field_input(item) + if value is not None: + field_type = item.type + if field_type in ["url", "int", "number", "bool", "enum"]: + field_type = "text" + + field_args = { + "field_type": field_type, + "field_label": item.label, + "field_value": value + } + record_field = vault_record.TypedField.new_field(**field_args) + + record_field.required = True + custom_fields.append(record_field) + + logger.info("") + while True: + title = input("Title for the SaaS configuration record> ") + if title != "": + break + logger.error(f"Require a record title.") + + record = 
vault_record.TypedRecord() + record.type_name = "login" + record.record_uid = utils.generate_uid() + record.record_key = utils.generate_aes_key() + record.title = title + + for item in custom_fields: + record.custom.append(item) + + vault = context.vault + folder = vault.vault_data.get_folder(shared_folder_uid) + folder_key = None # type: Optional[bytes] + if folder.folder_type == 'shared_folder_folder': + shared_folder_uid = folder.folder_scope_uid + elif folder.folder_type == 'shared_folder': + shared_folder_uid = folder.folder_uid + else: + shared_folder_uid = None + if shared_folder_uid and shared_folder_uid in vault.vault_data._shared_folders: + shared_folder = vault.vault_data.get_folder(shared_folder_uid) + folder_key = shared_folder.folder_key + + add_record = record_pb2.RecordAdd() + add_record.record_uid = utils.base64_url_decode(record.record_uid) + add_record.record_key = crypto.encrypt_aes_v2(record.record_key, vault.keeper_auth.auth_context.data_key) + add_record.client_modified_time = utils.current_milli_time() + add_record.folder_type = record_pb2.user_folder + if folder: + add_record.folder_uid = utils.base64_url_decode(folder.uid) + if folder.type == 'shared_folder': + add_record.folder_type = record_pb2.shared_folder + elif folder.type == 'shared_folder_folder': + add_record.folder_type = record_pb2.shared_folder_folder + if folder_key: + add_record.folder_key = crypto.encrypt_aes_v2(record.record_key, folder_key) + + data = vault_extensions.extract_typed_record_data(record) + json_data = vault_extensions.get_padded_json_bytes(data) + add_record.data = crypto.encrypt_aes_v2(json_data, record.record_key) + + if vault.keeper_auth.auth_context.enterprise_ec_public_key: + audit_data = vault_extensions.extract_audit_data(record) + if audit_data: + add_record.audit.version = 0 + add_record.audit.data = crypto.encrypt_ec( + json.dumps(audit_data).encode('utf-8'), vault.keeper_auth.auth_context.enterprise_ec_public_key) + + rq = 
record_pb2.RecordsAddRequest() + rq.records.append(add_record) + rs = vault.keeper_auth.execute_auth_rest('vault/records_add', rq, response_type=record_pb2.RecordsModifyResponse) + record_rs = next((x for x in rs.records if utils.base64_url_encode(x.record_uid) == record.record_uid), None) + if record_rs: + if record_rs.status != record_pb2.RS_SUCCESS: + raise KeeperApiError(record_rs.status, rs.message) + record.revision = rs.revision + + vault.sync_down() + + # If this is not a built-in or custom script, we need to attach it to the config record. + if plugin_code_bytes is not None and plugin.file_name: + + with TemporaryDirectory() as temp_dir: + vault.sync_down() + + existing_record = vault.vault_data.load_record(record.record_uid) + if existing_record is None: + logger.error(f"Could not load the config record {record.record_uid} to attach script.") + return + + temp_file = os.path.join(temp_dir, plugin.file_name) + with open(temp_file, "wb") as fh: + fh.write(plugin_code_bytes) + fh.close() + task = attachment.FileUploadTask(temp_file) + task.title = f"{plugin.name} Script" + task.mime_type = "text/x-python" + + if plugin.file_sig: + script_signature = make_script_signature(plugin_code_bytes) + if script_signature != plugin.file_sig: + raise ValueError("The plugin signature in catalog does not match what was downloaded.") + + attachment.upload_attachments(context, existing_record, [task]) + + record.fields = [ + vault_record.TypedField.new_field( + field_type="fileRef", + field_value=list(existing_record.linked_keys.keys())) + ] + + record_management.update_record(context, existing_record) + context.vault.sync_down() + + logger.info("") + logger.info(f"Created SaaS configuration record with UID of {record.record_uid}") + logger.info("") + logger.info("Assign this configuration to a user using the following command.") + logger.info(f" pam action saas set -c {record.record_uid} -u ") + logger.info(f" See pam action saas set --help for more information.") + + def 
execute(self, context: KeeperParams, **kwargs): + + do_list = kwargs.get("do_list", False) + do_info = kwargs.get("do_info", False) + do_create = kwargs.get("do_create", False) + do_update = kwargs.get("do_update", False) + shared_folder_uid = kwargs.get("shared_folder_uid") + + use_plugin = kwargs.get("plugin") + gateway = kwargs.get("gateway") + configuration_uid = kwargs.get('configuration_uid') + + vault = context.vault + gateway_context = GatewayContext.from_gateway(vault=vault, gateway=gateway, configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return None + + plugins = get_plugins_map(context=context, gateway_context=gateway_context) + + if do_list: + self._show_list(plugins) + elif use_plugin is not None: + + if use_plugin not in plugins: + logger.error(f"Cannot find '{use_plugin}' in the catalog.") + return None + + plugin = plugins[use_plugin] + + if do_info: + self._show_plugin_info(plugin=plugin) + + elif do_create: + + shared_folders = gateway_context.get_shared_folders(vault) + if shared_folder_uid is None: + if len(shared_folders) == 1: + shared_folder_uid = shared_folders[0].get("uid") + else: + logger.error(f"Multiple shared folders found. Please use '-s' to select a shared folder.") + if next((x for x in shared_folders if x.get("uid") == shared_folder_uid), None) is None: + logger.error(f"The shared folder is not part of the gateway application.") + return None + + # For catalog plugins, we need to download the python file from GitHub. 
+ plugin_code_bytes = None + if plugin.type == "catalog" and plugin.file: + res = utils.ssl_aware_get(plugin.file) + if res.ok is False: + logger.error(f"Could not download the script from GitHub.") + return None + plugin_code_bytes = res.content + + self._create_config( + context=context, + plugin=plugin, + shared_folder_uid=shared_folder_uid, + plugin_code_bytes=plugin_code_bytes) + elif do_update: + pass + else: + self.get_parser().print_help() + else: + if do_update: + pass + else: + self.get_parser().print_help() + + +class PAMActionSaasSetCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam action saas set') + PAMActionSaasSetCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='The UID of the User record') + parser.add_argument('--config-record-uid', '-c', required=True, dest='config_record_uid', + action='store', help='The UID of the record that has SaaS configuration') + parser.add_argument('--resource-uid', '-r', required=False, dest='resource_uid', action='store', + help='The UID of the Resource record, if needed.') + + def execute(self, context: KeeperParams, **kwargs): + + user_uid = kwargs.get("user_uid") + resource_uid = kwargs.get("resource_uid") + config_record_uid = kwargs.get("config_record_uid") + + logger.info("") + + vault = context.vault + + # Check to see if the record exists. + user_record = vault.vault_data.get_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + # Make sure this user is a PAM User. 
+ if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User.") + return + + record_rotation = params.record_rotation_cache.get(user_record.record_uid) + if record_rotation is not None: + configuration_uid = record_rotation.get("configuration_uid") + else: + logger.error(f"The user record does not have any rotation settings.") + return + + if configuration_uid is None: + logger.error(f"The user record does not have the configuration record set in the rotation settings.") + return + + gateway_context = GatewayContext.from_configuration_uid(vault=vault, configuration_uid=configuration_uid) + + if gateway_context is None: + logger.error(f"The user record does not have the set gateway") + return + + plugins = get_plugins_map(context=context, gateway_context=gateway_context) + if plugins is None: + return + + # Check to see if the config record exists. + config_record = vault.vault_data.get_record(config_record_uid) + if config_record is None: + logger.error(f"The SaaS configuration record does not exists.") + return + + # Make sure this config is a Login record. 
+ + if config_record.record_type not in ["login", "saasConfiguration"]: + logger.error(f"The SaaS configuration record is not a SaaS configuration record: " + f"{config_record.record_type}") + return + + config_record = vault.vault_data.load_record(config_record_uid) + + plugin_name_field = next((x for x in config_record.custom if x.label == "SaaS Type"), None) + if plugin_name_field is None: + logger.error(f"The SaaS configuration record is missing the custom field label 'SaaS Type'") + return + + plugin_name = None + if plugin_name_field.value is not None and len(plugin_name_field.value) > 0: + plugin_name = plugin_name_field.value[0] + + if plugin_name is None: + logger.error(f"The SaaS configuration record's custom field label 'SaaS Type' does not have a value.") + return + + if plugin_name not in plugins: + logger.error(f"The SaaS configuration record's custom field label 'SaaS Type' is not supported by the " + "gateway or the value is not correct.") + return + + plugin = plugins[plugin_name] + + # Make sure the SaaS configuration record has correct custom fields. + missing_fields = [] + for field in plugin.fields: + if field.required is True and field.default_value is None: + found = next((x for x in config_record.custom if x.label == field.label), None) + if not found: + missing_fields.append(field.label.strip()) + + if len(missing_fields) > 0: + logger.error(f"The SaaS configuration record is missing the following required custom fields: " + f'{", ".join(missing_fields)}') + return + + parent_uid = gateway_context.configuration_uid + + # Not sure if SaaS type rotation should be limited to NOOP rotation. + # Allow a resource record to be used. + if resource_uid is not None: + # Check to see if the record exists. + resource_record = vault.vault_data.load_record(resource_uid) + if resource_record is None: + logger.error(f"The resource record does not exists.") + return + + # Make sure this user is a PAM User. 
+ if user_record.record_type in [PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY]: + logger.error(f"The resource record does not have the correct record type.") + return + + parent_uid = resource_uid + + record_link = RecordLink(record=gateway_context.configuration, context=context, fail_on_corrupt=False) + acl = record_link.get_acl(user_uid, parent_uid) + if acl is None: + if resource_uid is not None: + logger.error(f"There is no relationship between the user and the resource record.") + else: + logger.error(f"There is no relationship between the user and the configuration record.") + return + + if acl.rotation_settings is None: + acl.rotation_settings = UserAclRotationSettings() + + if resource_uid is not None and acl.rotation_settings.noop is True: + logger.error(f"The rotation is flagged as No Operation, however you passed in a resource record. " + f"This combination is not allowed.") + return + + # If there is a resource record, it not NOOP. + # If there is NO resource record, it is NOOP. + # However, if this is an IAM User, don't set the NOOP + if acl.is_iam_user is False: + acl.rotation_settings.noop = resource_uid is None + + # Make sure we are not re-adding the same SaaS config. 
+ if config_record_uid in acl.rotation_settings.saas_record_uid_list: + logger.error(f"The SaaS configuration record is already being used for this user.") + return + + acl.rotation_settings.saas_record_uid_list = [config_record_uid] + + record_link.belongs_to(user_uid, parent_uid, acl=acl) + record_link.save() + + logger.info(f"Setting {plugin_name} rotation for the user record.") + logger.info("") + + + +class PAMActionSaasRemoveCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-saas-remove') + PAMActionSaasRemoveCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='The UID of the User record') + parser.add_argument('--resource-uid', '-r', required=False, dest='resource_uid', action='store', + help='The UID of the Resource record, if needed.') + + def execute(self, context: KeeperParams, **kwargs): + + user_uid = kwargs.get("user_uid") # type: str + resource_uid = kwargs.get("resource_uid") # type: str + + logger.info("") + vault = context.vault + + # Check to see if the record exists. + user_record = vault.vault_data.get_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + # Make sure this user is a PAM User. 
+ if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User.") + return + + record_rotation = params.record_rotation_cache.get(user_record.record_uid) + if record_rotation is not None: + configuration_uid = record_rotation.get("configuration_uid") + else: + logger.error(f"The user record does not have any rotation settings.") + return + + if configuration_uid is None: + logger.error(f"The user record does not have the configuration record set in the rotation settings.") + return + + gateway_context = GatewayContext.from_configuration_uid(vault=vault, configuration_uid=configuration_uid) + + if gateway_context is None: + logger.error(f"The user record does not have the set gateway") + return + + # Don't check config record + # Just accept the record UID; the record might not exist anymore. + + parent_uid = gateway_context.configuration_uid + + # Not sure if SaaS type rotation should be limited to NOOP rotation. + # Allow a resource record to be used. + if resource_uid is not None: + # Check to see if the record exists. + resource_record = vault.vault_data.get_record(resource_uid) + if resource_record is None: + logger.error(f"The resource record does not exists.") + return + + # Make sure this user is a PAM User. 
+ if user_record.record_type in [PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY]: + logger.error(f"The resource record does not have the correct record type.") + return + + parent_uid = resource_uid + + record_link = RecordLink(record=gateway_context.configuration, context=context, fail_on_corrupt=False) + acl = record_link.get_acl(user_uid, parent_uid) + if acl is None: + if resource_uid is not None: + logger.error(f"There is no relationship between the user and the resource record.") + else: + logger.error(f"There is no relationship between the user and the configuration record.") + return + + if acl.rotation_settings is None: + acl.rotation_settings = UserAclRotationSettings() + + if resource_uid is not None and acl.rotation_settings.noop is True: + logger.error(f"The rotation is flagged as No Operation, however you passed in a resource record. " + f"This combination is not allowed.") + return + + # An empty array removes the SaaS config. + acl.rotation_settings.saas_record_uid_list = [] + + record_link.belongs_to(user_uid, parent_uid, acl) + record_link.save() + + logger.info(f"Removing SaaS service rotation from the user record.") + + +class PAMActionSaasUserCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-saas-user') + PAMActionSaasUserCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--user-record-uid', '-u', required=True, dest='user_uid', action='store', + help='The UID of the User record') + + def execute(self, context: KeeperParams, **kwargs): + + user_uid = kwargs.get("user_uid") + + logger.info("") + vault = context.vault + + # Check to see if the record exists. + user_record = vault.vault_data.get_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + # Make sure this user is a PAM User. 
+ if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User.") + return + + record_rotation = params.record_rotation_cache.get(user_record.record_uid) + if record_rotation is not None: + configuration_uid = record_rotation.get("configuration_uid") + else: + logger.error(f"The user record does not have any rotation settings.") + return + + if configuration_uid is None: + logger.error(f"The user record does not have the configuration record set in the rotation settings.") + return + + gateway_context = GatewayContext.from_configuration_uid(vault=vault, configuration_uid=configuration_uid) + + if gateway_context is None: + logger.error(f"The user record does not have the set gateway") + return + + plugins = get_plugins_map(context, gateway_context) + + record_link = RecordLink(record=gateway_context.configuration, context=context, fail_on_corrupt=False) + user_vertex = record_link.get_record_link(user_uid) + if user_vertex is None: + logger.error(f"Cannot find the user in the record link graph.") + return + + logger.info(f"User: {user_record.title}") + + missing_configs = [] + + # User's can have multiple ACL edges to different parents. + # One of those ACL edges, in the rotation settings, may a populated saas_record_uid_list + for parent_vertex in user_vertex.belongs_to_vertices(): + + # Check to see if the record exists. 
+ parent_record = vault.vault_data.get_record(parent_vertex.uid) + if parent_record is None: + logger.error(f"* Parent record UID {parent_vertex.uid} does not exists.") + logger.error(f" The record may have been deleted, however the relationship still exists.") + logger.info("") + continue + + logger.info(f" * {parent_record.title}, {parent_record.record_type}") + logger.info("") + + acl = record_link.get_acl(user_uid, parent_vertex.uid) + if acl is not None and acl.rotation_settings is not None: + saas_record_uid_list = acl.rotation_settings.saas_record_uid_list + if saas_record_uid_list is None or len(saas_record_uid_list) == 0: + logger.error(f" The user does not have any SaaS service rotations.") + return + + for config_record_uid in saas_record_uid_list: + config_record = vault.vault_data.get_record(config_record_uid) + if config_record is None: + logger.error(f" * Record UID {config_record_uid} not longer exists.") + continue + logger.info(f" {config_record.title}") + + plugin_name = "" + saas_type_field = next((x for x in config_record.custom if x.label == "SaaS Type"), None) + if (saas_type_field is not None and saas_type_field.value is not None + and len(saas_type_field.value) > 0): + plugin_name = saas_type_field.value[0] + + plugin = plugins.get(plugin_name) + + # This might have been a valid plugin, or the name is mistyped, so it's not supported. 
+ if plugin is None: + plugin_name += " (Not Supported)" + + rotation_active = "Active" + rotation_active_field = next((x for x in config_record.custom if x.label == "Active"), + None) + + if (rotation_active_field is not None and rotation_active_field.value is not None + and len(rotation_active_field.value) > 0): + is_active = dag_utils.value_to_boolean(rotation_active_field.value[0]) + if is_active is False: + rotation_active = "Inactive" + + logger.info(f" SaaS Type: {plugin_name}") + logger.info(f" Config Record UID: {config_record.record_uid}") + logger.info(f" Active: {rotation_active}") + + if plugin is not None: + + for field in plugin.fields: + value = next((x.value for x in config_record.custom if x.label == field.label), None) + if value is not None: + if len(value) > 0: + value = value[0] + else: + value = None + if value is None: + if field.default_value is not None: + value = f"{field.default_value} (Default)" + else: + value = "Not Set" + logger.info(f" {field.label}: {value}") + logger.info("") + + +class PAMActionSaasUpdateCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-saas-update') + PAMActionSaasUpdateCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name of UID.') + parser.add_argument('--configuration-uid', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--all', '-a', required=False, dest='do_all', action='store_true', + help='Update all configurations.') + parser.add_argument('--config-record-uid', '-c', required=False, dest='config_uid', action='store', + help='Update a specific configuration.') + parser.add_argument('--dry-run', required=False, dest='do_dry_run', 
action='store_true', + help='Dry run. Do not save any changes.') + + @staticmethod + def get_field_values(record: vault_record.TypedRecord, field_type: str) -> List[str]: + return next( + (f.value + for f in record.fields + if f.type == field_type), + None + ) + + @classmethod + def _get_file_refs(cls, record: vault_record.TypedRecord) -> List[str]: + return list(next((x.value for x in record.fields if x.type == "fileRef"), [])) + + @classmethod + def _update_script(cls, context: KeeperParams, config_record: vault_record.TypedRecord, plugin: SaasCatalog): + + if plugin.type != "catalog": + raise ValueError("Cannot download script for non-catalog plugin.") + + if not plugin.file: + raise ValueError("Plugin does not have a file URL.") + + if not plugin.file_name: + raise ValueError("Plugin does not have a file name.") + + logger.info(" * downloading updated plugin script") + res = utils.ssl_aware_get(plugin.file) + if res.ok is False: + raise ValueError("Could download updated script from GitHub") + plugin_code_bytes = res.content + + new_script_sig = make_script_signature(plugin_code_bytes=plugin_code_bytes) + + if plugin.file_sig: + logger.debug(f"downloaded {new_script_sig} vs catalog {plugin.file_sig}") + if new_script_sig != plugin.file_sig: + raise ValueError("The plugin signature in catalog does not match what was downloaded.") + + with TemporaryDirectory() as temp_dir: + temp_file = os.path.join(temp_dir, plugin.file_name) + with open(temp_file, "wb") as fh: + fh.write(plugin_code_bytes) + fh.close() + + task = attachment.FileUploadTask(temp_file) + task.title = f"{plugin.name} Script" + task.mime_type = "text/x-python" + + # Get the existing attached; we are going to remove these + existing_file_refs = cls._get_file_refs(config_record) + logger.debug(f"existing file ref: {existing_file_refs}") + + attachment.upload_attachments(context.vault, config_record, [task]) + + new_file_refs = cls._get_file_refs(config_record) + logger.debug(f"new file ref: 
{new_file_refs}") + + if existing_file_refs is not None: + logger.debug("existing file ref exists") + for existing_file_ref in existing_file_refs: # type: str + logger.debug(f" * {existing_file_ref}") + if existing_file_ref in new_file_refs: + new_file_refs.remove(existing_file_ref) + else: + logger.debug("no existing file ref, use new file ref") + + logger.debug(f"save file ref: {new_file_refs}") + + config_record.fields = [ + vault_record.TypedField.new_field( + field_type="fileRef", + field_value=new_file_refs + ) + ] + + record_management.update_record(context.vault, config_record) + context.sync_data = True + + logger.info(f" * the plugin script is now up-to-date.") + + @classmethod + def _missing_fields(cls, config_record: vault_record.TypedRecord, plugin: SaasCatalog) -> List[str]: + + # Make the record into a map by the field label + records_field_map = {} + for field in config_record.custom: + records_field_map[field.label] = field + + missing_fields = [] + for field in plugin.fields: + + # We only care about required fields. 
+ if not field.required or field.default_value is not None: + continue + record_field = records_field_map.get(field.label) + if (record_field is None + or record_field.value is None + or len(record_field.value) == 0 + or record_field.value[0] is None + or record_field.value[0] == ""): + missing_fields.append(field.label) + return missing_fields + + @classmethod + def _update_config(cls, + context: KeeperParams, + plugins: dict[str, SaasCatalog], + config_record: vault_record.TypedRecord, + dry_run: bool = False) -> Optional[SaasCatalog]: + + plugin_field = next((x for x in config_record.custom if x.label == "SaaS Type"), None) + if plugin_field is None or len(plugin_field.value) == 0: + logger.debug("record is not a SaaS Configuration record") + raise RecordNotConfigException() + plugin_name = plugin_field.value[0] + logger.debug(f"plugin name is {plugin_name}") + + plugin = plugins.get(plugin_name) + if plugin is not None and plugin.type == "catalog": + + missing_fields = cls._missing_fields(config_record=config_record, plugin=plugin) + + logger.info(f"{config_record.title} ({config_record.record_uid}) - {plugin_name}") + logger.debug(f"plugin is {plugin_name} for config {config_record.title}") + attachments = list(attachment.prepare_attachment_download(context.vault, config_record.record_uid)) + + # If there is no script, just attach script to record. + # Someone might have deleted the script from the record. 
+ if len(attachments) == 0: + logger.info(" * the record does not contain a plugin script.") + logger.debug(" * configuration did not have script, add current script.") + + if not dry_run: + cls._update_script( + context=context, + config_record=config_record, + plugin=plugin, + ) + else: + logger.info(f" * not updating script due to dry run.") + + if len(missing_fields) == 0: + logger.info(f" * the configuration record fields are up-to-date.") + else: + logger.error(f" * the configuration record's required field(s) are missing or blank: " + f"{', '.join(missing_fields)}") + logger.info("") + return plugin + + logger.debug(f"found {len(attachments)} attached script(s).") + + if len(attachments) > 1: + raise ValueError("Found multiple scripts. Only one script is allowed per SaaS Configuration record.") + + for atta in attachments: + with TemporaryDirectory() as temp_dir: + if not plugin.file_name: + logger.debug("plugin does not have a file name, using default") + temp_file = str(os.path.join(temp_dir, f"{plugin.name}_script.py")) + else: + temp_file = str(os.path.join(temp_dir, plugin.file_name)) + logger.debug(f"download to {temp_file}") + + # download_to_file prints to the screen, we don't want that. 
+ log_level = logger.getEffectiveLevel() + try: + logger.setLevel(logging.WARNING) + atta.download_to_file(context.vault, temp_file) + finally: + logger.setLevel(log_level) + + with open(temp_file, "rb") as fh: + plugin_code_bytes = fh.read() + fh.close() + + attach_file_sig = make_script_signature(plugin_code_bytes=plugin_code_bytes) + + if plugin.file_sig: + logger.debug(f"attached {attach_file_sig} vs catalog {plugin.file_sig}") + sig_matches = attach_file_sig == plugin.file_sig + else: + logger.debug("plugin does not have a file signature, skipping verification") + sig_matches = True + + if not sig_matches: + logger.error(f" * the plugin script have changed.") + logger.debug("the script has changed, update") + + if not dry_run: + cls._update_script( + context=context, + config_record=config_record, + plugin=plugin, + ) + else: + logger.info(f" * not updating script due to dry run.") + else: + logger.info(f" * the plugin script is up-to-date.") + + if len(missing_fields) == 0: + logger.info(f" * the configuration record fields are up-to-date.") + else: + logger.error(f" * the configuration record's required field(s) are missing or blank: " + f"{', '.join(missing_fields)}") + + # If the record type is login, migrate to saasConfiguration + if config_record.record_type == "login": + logger.info(f" * migrate record type to SaaS Configuration.") + config_record.type_name = "saasConfiguration" + record_management.update_record(context.vault, config_record) + + logger.info("") + + logger.debug("plugin doesn't used attached scripts, or bad SaaS type in config record.") + return plugin + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") # type: str + do_all = kwargs.get("do_all", False) # type: bool + config_record_uid = kwargs.get("config_uid") # type: str + do_dry_run = kwargs.get("do_dry_run", False) # type: bool + + configuration_uid = kwargs.get('configuration_uid') # type Optional[str] + vault = context.vault + + 
gateway_context = GatewayContext.from_gateway(context=context, + gateway=gateway, + configuration_uid=configuration_uid) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + logger.info("") + + if do_dry_run: + logger.info(f"Dry run enabled. No changes will be saved.") + logger.info("") + + plugins = get_plugins_map( + context=context, + gateway_context=gateway_context + ) + + if do_all: + logger.debug("search vault for login record types") + for record in list(vault.vault_data.find_records(criteria=None, record_type=["login", "saasConfiguration"], record_version=None)): + logger.debug("--------------------------------------------------------------------------------------") + config_record = vault.vault_data.load_record(record.record_uid) + + logger.debug(f"checking record {record.record_uid}, {record.title}") + try: + self._update_config( + context=context, + plugins=plugins, + config_record=config_record, + dry_run=do_dry_run + ) + except RecordNotConfigException: + pass + except Exception as err: + logger.error(f" *{err}") + logger.debug(traceback.format_exc()) + logger.debug(f"ERROR (no fatal): {err}") + + context.sync_data = True + + elif config_record_uid is not None: + config_record = vault.vault_data.load_record(config_record_uid) + if config_record is None: + logger.error(f"Cannot find a record for UID {config_record_uid}.") + return + + try: + plugin = self._update_config( + context=context, + plugins=plugins, + config_record=config_record, + dry_run=do_dry_run + ) + if plugin is not None: + missing_fields = self._missing_fields(config_record=config_record, plugin=plugin) + + if len(missing_fields) > 0: + + # If we added a script, we need to sync down to get the record version number correct. 
+ vault.sync_down() + config_record = vault.vault_data.load_record(config_record_uid) + + # If the record type is login, migrate to saasConfiguration + if config_record.record_type == "login": + logger.debug("migrating from login to saasConfiguration record type") + config_record.type_name = "saasConfiguration" + + for required in [True, False]: + for field in plugin.fields: + if field.required is required: + current_value = get_record_field_value( + record=config_record, + label=field.label + ) + logger.info("") + value = get_field_input(field, current_value=current_value) + if value is not None: + set_record_field_value( + record=config_record, + label=field.label, + value=value + ) + + if not do_dry_run: + record_management.update_record(vault, config_record) + logger.info("") + logger.info(f"The SaaS configuration record has been updated.") + logger.info("") + else: + logger.info("") + logger.info(f"The SaaS configuration record was not saved due to dry run.") + logger.info("") + + vault.sync_down() + + except Exception as err: + logger.error("") + logger.debug(traceback.format_exc()) + logger.error(f"{err}.") + return + else: + logger.error("") + logger.error(f"Requires either the --all or --config-record-uid parameters.") + logger.info("") + PAMActionSaasUpdateCommand.parser.print_help() diff --git a/keepercli-package/src/keepercli/commands/pam/service/__init__.py b/keepercli-package/src/keepercli/commands/pam/service/__init__.py new file mode 100644 index 00000000..06c160ec --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/service/__init__.py @@ -0,0 +1,18 @@ +from __future__ import annotations +from keepersdk.helpers.keeper_dag.dag_utils import value_to_boolean +from keepersdk.vault import vault_online +from ....api import get_logger +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from keepersdk.helpers.keeper_dag.connection import ConnectionBase + + +def get_connection(vault: vault_online.VaultOnline) -> ConnectionBase: + 
if value_to_boolean(os.environ.get("USE_LOCAL_DAG", False)) is False: + from keepersdk.helpers.keeper_dag.connection.commander import Connection as CommanderConnection + return CommanderConnection(vault=vault, logger=get_logger()) + else: + from keepersdk.helpers.keeper_dag.connection.local import Connection as LocalConnection + return LocalConnection(vault=vault, logger=get_logger()) \ No newline at end of file diff --git a/keepercli-package/src/keepercli/commands/pam/service/service_commands.py b/keepercli-package/src/keepercli/commands/pam/service/service_commands.py new file mode 100644 index 00000000..3108b883 --- /dev/null +++ b/keepercli-package/src/keepercli/commands/pam/service/service_commands.py @@ -0,0 +1,365 @@ +import argparse +from ..discovery.__init__ import GatewayContext, PAMGatewayActionDiscoverCommandBase +from ....params import KeeperParams +from .... import api +from ....__init__ import __version__ +from ....commands import base + +from keepersdk.helpers.keeper_dag.user_service import UserService +from keepersdk.helpers.keeper_dag.dag_types import EdgeType, ServiceAcl, RefType +from keepersdk.helpers.keeper_dag.constants import PAM_MACHINE, PAM_USER +from keepersdk.helpers.keeper_dag.record_link import RecordLink + + +logger = api.get_logger() + + +class PAMActionServiceListCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-service-list') + PAMActionServiceListCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + + def execute(self, context: KeeperParams, **kwargs): + + gateway = 
kwargs.get("gateway") + vault = context.vault + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway=gateway, + configuration_uid=kwargs.get('configuration_uid')) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + user_service = UserService(record=gateway_context.configuration, context=context, fail_on_corrupt=False, + agent=f"Cmdr/{__version__}") + + service_map = {} + for resource_vertex in user_service.dag.get_root.has_vertices(edge_type=EdgeType.LINK): + resource_record = vault.vault_data.load_record(resource_vertex.uid) + if resource_record is None or resource_record.record_type != PAM_MACHINE: + continue + user_vertices = user_service.get_user_vertices(resource_vertex.uid) + if len(user_vertices) > 0: + for user_vertex in user_vertices: + user_record = vault.vault_data.load_record(user_vertex.uid) + if user_record is None: + continue + acl = user_service.get_acl(resource_record.record_uid, user_record.record_uid) + if acl is None or (acl.is_service is False and acl.is_task is False): + continue + if user_record.record_uid not in service_map: + service_map[user_record.record_uid] = { + "title": user_record.title, + "machines": [] + } + text = f"{resource_record.title} ({resource_record.record_uid}) :" + comma = "" + if acl.is_service: + text += f" Services" + comma = "," + if acl.is_task: + text += f"{comma} Scheduled Tasks" + if acl.is_iis_pool: + text += f"{comma} IIS Pools" + service_map[user_record.record_uid]["machines"].append(text) + + logger.info("") + printed_something = False + logger.info("User Mapping") + for user_uid in service_map: + user = service_map[user_uid] + printed_something = True + logger.info(f" {user['title']} ({user_uid})") + for machine in user["machines"]: + logger.info(f" * {machine}") + logger.info("") + if not printed_something: + logger.error(f"There are no service mappings.") + + + +class 
PAMActionServiceAddCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-service-add') + PAMActionServiceAddCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--machine-uid', '-m', required=True, dest='machine_uid', action='store', + help='The UID of the Windows Machine record') + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='The UID of the User record') + parser.add_argument('--type', '-t', required=True, choices=['service', 'task', 'iis'], dest='type', + action='store', help='Relationship to add [service, task, iis]') + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + machine_uid = kwargs.get("machine_uid") + user_uid = kwargs.get("user_uid") + rel_type = kwargs.get("type") + + logger.info("") + vault = context.vault + + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway=gateway, + configuration_uid=kwargs.get('configuration_uid')) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + if gateway_context is None: + logger.error(f"Cannot get gateway information. Gateway may not be up.") + return + + user_service = UserService(record=gateway_context.configuration, context=context, fail_on_corrupt=False, + agent=f"Cmdr/{__version__}") + record_link = RecordLink(record=gateway_context.configuration, context=context, fail_on_corrupt=False, + agent=f"Cmdr/{__version__}") + + ############### + + # Check to see if the record exists. 
+ machine_record = vault.vault_data.load_record(machine_uid) + if machine_record is None: + logger.error(f"The machine record does not exists.") + return + + # Make sure the record is a PAM Machine. + if machine_record.record_type != PAM_MACHINE: + logger.error(f"The machine record is not a PAM Machine.") + return + + # Make sure this machine is linked to the configuration record. + machine_rl = record_link.get_record_link(machine_record.record_uid) + if machine_rl is None: + logger.error(f"The machine record does not exists in the graph.") + return + + # Edges from provider and machine might be wrong. + # Should be a LINK edge, could be an ACL edge. + if (machine_rl.get_edge(record_link.dag.get_root, edge_type=EdgeType.LINK) is None and + machine_rl.get_edge(record_link.dag.get_root, edge_type=EdgeType.ACL) is None): + logger.error(f"The machine record does not belong to this gateway.") + return + + ############### + + # Check to see if the record exists. + user_record = vault.vault_data.load_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + # Make sure this user is a PAM User. + if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User.") + return + + record_rotation = params.record_rotation_cache.get(user_record.record_uid) + if record_rotation is not None: + controller_uid = record_rotation.get("configuration_uid") + if controller_uid is None or controller_uid != gateway_context.configuration_uid: + logger.error(f"The user record does not belong to this gateway. Cannot use this user.") + return + else: + logger.error(f"The user record does not have any rotation settings.") + return + + ######## + + # Make sure we are setting up a Windows machine. + # Linux and Mac do not use passwords in services and cron jobs; no need to link. 
+ os_field = next((x for x in machine_record.fields if x.label == "operatingSystem"), None) + if os_field is None: + logger.error(f"Cannot find the operating system field in this record.") + return + os_type = None + if len(os_field.value) > 0: + os_type = os_field.value[0] + if os_type is None: + logger.error(f"The operating system field of the machine record is blank.") + return + if os_type != "windows": + logger.error(f"The operating system is not Windows. " + "PAM can only rotate the services and scheduled task password on Windows.") + return + + # Get the machine service vertex. + # If it doesn't exist, create one. + machine_vertex = user_service.get_record_link(machine_record.record_uid) + if machine_vertex is None: + machine_vertex = user_service.dag.add_vertex( + uid=machine_record.record_uid, + name=machine_record.title, + vertex_type=RefType.PAM_MACHINE) + + # Get the user service vertex. + # If it doesn't exist, create one. + user_vertex = user_service.get_record_link(user_record.record_uid) + if user_vertex is None: + user_vertex = user_service.dag.add_vertex( + uid=user_record.record_uid, + name=user_record.title, + vertex_type=RefType.PAM_USER) + + # Get the existing service ACL and set the proper attribute. + acl = user_service.get_acl(machine_vertex.uid, user_vertex.uid) + if acl is None: + acl = ServiceAcl() + if rel_type == "service": + acl.is_service = True + elif rel_type == "task": + acl.is_task = True + else: + acl.is_iis_pool = True + + # Make sure the machine has a LINK connection to the configuration. + if not user_service.dag.get_root.has(machine_vertex): + user_service.belongs_to(gateway_context.configuration_uid, machine_vertex.uid) + + # Add our new ACL edge between the machine and the yser. 
+ user_service.belongs_to(machine_vertex.uid, user_vertex.uid, acl=acl) + + user_service.save() + + if rel_type == "service": + logger.info( + f"Success: Services running on this machine, using this user, will be updated and restarted after " + "password rotation." + ) + elif rel_type == "task": + logger.info( + f"Success: Scheduled tasks running on this machine, using this user, will be updated after " + "password rotation." + ) + else: + logger.info( + f"Success: IIS pools running on this machine, using this user, will be updated after " + "password rotation." + ) + + +class PAMActionServiceRemoveCommand(PAMGatewayActionDiscoverCommandBase): + + def __init__(self): + parser = argparse.ArgumentParser(prog='pam-action-service-remove') + PAMActionServiceRemoveCommand.add_arguments_to_parser(parser) + super().__init__(parser) + + @staticmethod + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--configuration-uid', '-c', required=False, dest='configuration_uid', + action='store', help='PAM configuration UID, if gateway has multiple.') + parser.add_argument('--machine-uid', '-m', required=True, dest='machine_uid', action='store', + help='The UID of the Windows Machine record') + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='The UID of the User record') + parser.add_argument('--type', '-t', required=True, choices=['service', 'task', 'iis'], dest='type', + action='store', help='Relationship to remove [service, task, iis]') + + def execute(self, context: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + machine_uid = kwargs.get("machine_uid") + user_uid = kwargs.get("user_uid") + rel_type = kwargs.get("type") + + logger.info("") + if not context.vault: + raise base.CommandError("Vault not found. 
Login to initialize the vault.") + vault = context.vault + + gateway_context = GatewayContext.from_gateway(vault=vault, + gateway=gateway, + configuration_uid=kwargs.get('configuration_uid')) + if gateway_context is None: + logger.error(f"Could not find the gateway configuration for {gateway}.") + return + + if gateway_context is None: + logger.error(f"Cannot get gateway information. Gateway may not be up.") + return + + user_service = UserService(record=gateway_context.configuration, context=context, fail_on_corrupt=False, + agent=f"Cmdr/{__version__}") + + machine_record = vault.vault_data.load_record(machine_uid) + if machine_record is None: + logger.error(f"The machine record does not exists.") + return + + if machine_record.record_type != PAM_MACHINE: + logger.error(f"The machine record is not a PAM Machine.") + return + + user_record = vault.vault_data.load_record(user_uid) + if user_record is None: + logger.error(f"The user record does not exists.") + return + + if user_record.record_type != PAM_USER: + logger.error(f"The user record is not a PAM User.") + return + + machine_vertex = user_service.get_record_link(machine_record.record_uid) + if machine_vertex is None: + logger.error(f"The machine does not exist in the mapping.") + return + + user_vertex = user_service.get_record_link(user_record.record_uid) + if user_vertex is None: + logger.error(f"The user does not exist in the mapping.") + return + + acl = user_service.get_acl(machine_vertex.uid, user_vertex.uid) + if acl is None: + logger.error(f"The user did not control any services, scheduled tasks, or IIS pools on the machine.") + return + + if rel_type == "service": + acl.is_service = False + elif rel_type == "task": + acl.is_task = False + else: + acl.is_iis_pool = False + + if not user_service.dag.get_root.has(machine_vertex): + user_service.belongs_to(gateway_context.configuration_uid, machine_vertex.uid) + + user_service.belongs_to(machine_vertex.uid, user_vertex.uid, acl=acl) + 
user_service.save()
+
+        if rel_type == "service":
+            logger.info(
+                f"Success: Services running on this machine will no longer have their password changed when this "
+                "user's password is rotated."
+            )
+        elif rel_type == "task":
+            logger.info(
+                f"Success: Scheduled tasks running on this machine will no longer have their password changed "
+                "when this user's password is rotated."
+            )
+        else:
+            logger.info(
+                f"Success: IIS pools running on this machine will no longer have their password changed "
+                "when this user's password is rotated."
+            )
diff --git a/keepercli-package/src/keepercli/commands/secrets_manager.py b/keepercli-package/src/keepercli/commands/secrets_manager.py
index 25e2f41e..f7a6536f 100644
--- a/keepercli-package/src/keepercli/commands/secrets_manager.py
+++ b/keepercli-package/src/keepercli/commands/secrets_manager.py
@@ -70,7 +70,7 @@ def add_arguments_to_parser(parser: argparse.ArgumentParser):
         '-f', '--force', dest='force', action='store_true', help='Force add or remove app'
     )
     parser.add_argument(
-        '--email', action='store', type=str, dest='email', help='Email of user to grant / remove application access to / from'
+        '--email', '-e', action='store', type=str, dest='email', help='Email of user to grant / remove application access to / from'
     )
     parser.add_argument(
         '--admin', action='store_true', help='Allow share recipient to manage application'
     )
@@ -385,7 +385,8 @@ def add_client(
         first_access_expire_duration_ms=first_access_expire_duration_ms,
         access_expire_in_ms=access_expire_in_ms,
         master_key=master_key,
-        server=server
+        server=server,
+        client_type=GENERAL
     )
     tokens.append(token_data['token_info'])
diff --git a/keepercli-package/src/keepercli/commands/shares.py b/keepercli-package/src/keepercli/commands/shares.py
index 2794969e..5976b429 100644
--- a/keepercli-package/src/keepercli/commands/shares.py
+++ b/keepercli-package/src/keepercli/commands/shares.py
@@ -132,7 +132,6 @@ def execute(self, context: KeeperParams, **kwargs) -> None:
         if not
context.vault:
             raise ValueError("Vault is not initialized.")
         vault = context.vault
-
         uid_or_name = kwargs.get('record')
         if not uid_or_name:
             return self.get_parser().print_help()
@@ -223,7 +222,6 @@ def get_contact(user, contacts):
     return None
-
 
 class ShareFolderCommand(base.ArgparseCommand):
     def __init__(self):
         self.parser = argparse.ArgumentParser(
diff --git a/keepercli-package/src/keepercli/helpers/email_utils.py b/keepercli-package/src/keepercli/helpers/email_utils.py
new file mode 100644
index 00000000..1a39cc42
--- /dev/null
+++ b/keepercli-package/src/keepercli/helpers/email_utils.py
@@ -0,0 +1,1381 @@
+from abc import ABC, abstractmethod
+import os
+import sys
+from typing import Any, Dict, List, Optional
+from dataclasses import dataclass, field
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+import smtplib
+import ssl
+import logging
+
+import keepersdk
+from keepersdk.vault import vault_online, vault_record, vault_extensions, record_management
+
+from .. import api
+
+logger = api.get_logger()
+
+
+def get_installation_method():
+    """
+    Detect Keeper Commander installation method.
+
+    Returns:
+        str: 'binary' (PyInstaller frozen), 'pip' (installed via pip), or 'source' (development)
+    """
+
+    # Check if running as PyInstaller binary.
+    # Requires 'import sys' above: PyInstaller sets sys.frozen on frozen apps.
+    if getattr(sys, 'frozen', False):
+        return 'binary'
+
+    # Check if installed via pip
+    location = keepersdk.__file__
+
+    if 'site-packages' in location or 'dist-packages' in location:
+        return 'pip'
+
+    # Running from source
+    return 'source'
+
+
+def check_provider_dependencies(provider: str) -> tuple:
+    """
+    Check if dependencies for a provider are available.
+ + Args: + provider: Provider name (smtp, ses, sendgrid, gmail-oauth, microsoft-oauth) + + Returns: + tuple: (dependencies_available: bool, error_message: str) + """ + if provider == 'smtp': + # SMTP uses standard library, always works + return (True, '') + + installation_method = get_installation_method() + + # Binary installations only support SMTP + if installation_method == 'binary': + return ( + False, + f'{provider} is not available in the binary installation.\n' + f'\n' + f'To use this provider, you must switch to the PyPI version:\n' + f' 1. Uninstall the binary version\n' + f' 2. Install via pip with email dependencies:\n' + f' pip install keepercommander[email]\n' + f'\n' + f'The binary version only supports SMTP for email functionality.' + ) + + # Check for required packages on pip/source installations + if provider == 'ses': + try: + import boto3 + return (True, '') + except ImportError: + return ( + False, + 'AWS SES requires additional dependencies.\n' + 'Install with:\n' + ' pip install keepercommander[email-ses]\n' + ' # or install all email providers:\n' + ' pip install keepercommander[email]' + ) + + elif provider == 'sendgrid': + try: + import sendgrid + return (True, '') + except ImportError: + return ( + False, + 'SendGrid requires additional dependencies.\n' + 'Install with:\n' + ' pip install keepercommander[email-sendgrid]\n' + ' # or install all email providers:\n' + ' pip install keepercommander[email]' + ) + + elif provider == 'gmail-oauth': + try: + import google.auth + import googleapiclient + return (True, '') + except ImportError: + return ( + False, + 'Gmail OAuth requires additional dependencies.\n' + 'Install with:\n' + ' pip install keepercommander[email-gmail-oauth]\n' + ' # or install all email providers:\n' + ' pip install keepercommander[email]' + ) + + elif provider == 'microsoft-oauth': + try: + import msal + return (True, '') + except ImportError: + return ( + False, + 'Microsoft OAuth requires additional dependencies.\n' 
+ 'Install with:\n' + ' pip install keepercommander[email-microsoft-oauth]\n' + ' # or install all email providers:\n' + ' pip install keepercommander[email]' + ) + + return (True, '') + + +@dataclass +class EmailConfig: + """ + Email configuration for sending emails via various providers. + + Stored as Keeper records (login type) with encrypted credentials. + See ADR Decision 1 in ONBOARDING_FEATURE_IMPLEMENTATION_PLAN.md + + Supported providers: + - smtp: Standard SMTP with username/password + - ses: AWS Simple Email Service + - sendgrid: SendGrid API + - gmail-oauth: Gmail with OAuth 2.0 (uses Gmail API) + - microsoft-oauth: Microsoft 365/Outlook with OAuth 2.0 (uses Graph API) + """ + record_uid: str + name: str + provider: str # 'smtp', 'ses', 'sendgrid', 'gmail-oauth', 'microsoft-oauth' + from_address: str + from_name: str = "Keeper Commander" + + # SMTP-specific + smtp_host: Optional[str] = None + smtp_port: int = 587 + smtp_username: Optional[str] = None + smtp_password: Optional[str] = None + smtp_use_tls: bool = True + smtp_use_ssl: bool = False + + # AWS SES-specific + aws_region: Optional[str] = None + aws_access_key: Optional[str] = None + aws_secret_key: Optional[str] = None + + # SendGrid-specific + sendgrid_api_key: Optional[str] = None + + # OAuth-specific (gmail-oauth, microsoft-oauth) + oauth_client_id: Optional[str] = None + oauth_client_secret: Optional[str] = None + oauth_access_token: Optional[str] = None + oauth_refresh_token: Optional[str] = None + oauth_token_expiry: Optional[str] = None # ISO 8601 format + oauth_scopes: Optional[List[str]] = None + oauth_tenant_id: Optional[str] = None # For Microsoft 365 (can be 'common', 'organizations', or specific tenant ID) + + # Additional metadata + custom_fields: Dict[str, Any] = field(default_factory=dict) + + # Internal flag to track OAuth token updates (not stored in Keeper) + _oauth_tokens_updated: bool = field(default=False, init=False) + + def validate(self) -> List[str]: + """ + Validate 
email configuration completeness. + + Returns: + List of validation error messages (empty if valid) + """ + errors = [] + + if not self.provider: + errors.append("Provider is required") + + if not self.from_address: + errors.append("From address is required") + + if self.provider == 'smtp': + if not self.smtp_host: + errors.append("SMTP host is required") + if not self.smtp_username: + errors.append("SMTP username is required") + if not self.smtp_password: + errors.append("SMTP password is required") + + elif self.provider == 'ses': + if not self.aws_region: + errors.append("AWS region is required for SES") + if not self.aws_access_key: + errors.append("AWS access key is required for SES") + if not self.aws_secret_key: + errors.append("AWS secret key is required for SES") + + elif self.provider == 'sendgrid': + if not self.sendgrid_api_key: + errors.append("SendGrid API key is required") + + elif self.provider in ('gmail-oauth', 'microsoft-oauth'): + # OAuth providers require either interactive auth OR manual token entry + if not self.oauth_client_id: + errors.append(f"OAuth client ID is required for {self.provider}") + if not self.oauth_client_secret: + errors.append(f"OAuth client secret is required for {self.provider}") + + # If no access token, we'll do interactive OAuth flow + # If access token provided, we should also have refresh token + if self.oauth_access_token and not self.oauth_refresh_token: + errors.append("OAuth refresh token is required when access token is provided") + + # Microsoft requires tenant ID + if self.provider == 'microsoft-oauth' and not self.oauth_tenant_id: + errors.append("OAuth tenant ID is required for Microsoft (use 'common' for multi-tenant)") + + else: + errors.append(f"Unknown provider: {self.provider}") + + return errors + + def is_oauth_provider(self) -> bool: + """Check if this config uses OAuth authentication.""" + return self.provider in ('gmail-oauth', 'microsoft-oauth') + + def tokens_need_refresh(self) -> bool: + """ + 
Check if OAuth tokens need to be refreshed. + + Returns: + True if tokens are expired or will expire soon (within 5 minutes) + """ + if not self.is_oauth_provider(): + return False + + if not self.oauth_token_expiry: + # No expiry set, assume tokens are valid + return False + + try: + from datetime import datetime, timedelta, timezone + expiry = datetime.fromisoformat(self.oauth_token_expiry.replace('Z', '+00:00')) + now = datetime.now(timezone.utc) + # Refresh if expired or expiring within 5 minutes + return expiry <= now + timedelta(minutes=5) + except Exception: + # If we can't parse the expiry, assume we need to refresh + return True + + +class EmailProvider(ABC): + """ + Abstract base class for email providers. + + Each provider implements send() method for their specific API/protocol. + """ + + def __init__(self, config: EmailConfig): + self.config = config + validation_errors = config.validate() + if validation_errors: + raise ValueError(f"Invalid email configuration: {', '.join(validation_errors)}") + + @abstractmethod + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via provider. + + Args: + to: Recipient email address + subject: Email subject + body: Email body (plain text or HTML) + html: True if body is HTML + + Returns: + True if sent successfully, False otherwise + + Raises: + Exception: If send fails with unrecoverable error + """ + pass + + @abstractmethod + def test_connection(self) -> bool: + """ + Test connection to email provider. + + Returns: + True if connection successful, False otherwise + """ + pass + + +class SMTPEmailProvider(EmailProvider): + """ + SMTP email provider implementation. + + Supports standard SMTP with TLS/SSL for Gmail, Office 365, and other SMTP servers. + """ + + def __init__(self, config: EmailConfig): + super().__init__(config) + self._connection = None + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via SMTP. 
+ + Args: + to: Recipient email address + subject: Email subject + body: Email body + html: True if body is HTML + + Returns: + True if sent successfully + + Raises: + smtplib.SMTPException: If SMTP operation fails + """ + try: + # Create message + msg = MIMEMultipart('alternative') + msg['Subject'] = subject + msg['From'] = f"{self.config.from_name} <{self.config.from_address}>" + msg['To'] = to + + # Attach body + if html: + part = MIMEText(body, 'html') + else: + part = MIMEText(body, 'plain') + msg.attach(part) + + # Connect and send + if self.config.smtp_use_ssl: + # Use SMTP_SSL for port 465 + context = ssl.create_default_context() + with smtplib.SMTP_SSL( + self.config.smtp_host, + self.config.smtp_port, + context=context + ) as server: + server.login(self.config.smtp_username, self.config.smtp_password) + server.send_message(msg) + else: + # Use SMTP with STARTTLS for port 587 + with smtplib.SMTP( + self.config.smtp_host, + self.config.smtp_port, + timeout=30 + ) as server: + server.ehlo() + if self.config.smtp_use_tls: + context = ssl.create_default_context() + server.starttls(context=context) + server.ehlo() + server.login(self.config.smtp_username, self.config.smtp_password) + server.send_message(msg) + + logging.info(f"[EMAIL] SMTP email sent to {to} via {self.config.smtp_host}") + return True + + except smtplib.SMTPAuthenticationError as e: + logging.error(f"[EMAIL] SMTP authentication failed: {e}") + raise + except smtplib.SMTPException as e: + logging.error(f"[EMAIL] SMTP error: {e}") + raise + except Exception as e: + logging.error(f"[EMAIL] Unexpected error sending email: {e}") + raise + + def test_connection(self) -> bool: + """ + Test SMTP connection and authentication. 
+ + Returns: + True if connection successful + """ + try: + if self.config.smtp_use_ssl: + context = ssl.create_default_context() + with smtplib.SMTP_SSL( + self.config.smtp_host, + self.config.smtp_port, + context=context, + timeout=10 + ) as server: + server.login(self.config.smtp_username, self.config.smtp_password) + else: + with smtplib.SMTP( + self.config.smtp_host, + self.config.smtp_port, + timeout=10 + ) as server: + server.ehlo() + if self.config.smtp_use_tls: + context = ssl.create_default_context() + server.starttls(context=context) + server.ehlo() + server.login(self.config.smtp_username, self.config.smtp_password) + + logging.info(f"[EMAIL] SMTP connection test successful: {self.config.smtp_host}") + return True + + except Exception as e: + logging.error(f"[EMAIL] SMTP connection test failed: {e}") + return False + + +class SendGridEmailProvider(EmailProvider): + """ + SendGrid email provider implementation. + + Uses SendGrid HTTP API for sending emails. + """ + + def __init__(self, config: EmailConfig): + super().__init__(config) + # Import here to avoid dependency if not using SendGrid + try: + from sendgrid import SendGridAPIClient + from sendgrid.helpers.mail import Mail + self.SendGridAPIClient = SendGridAPIClient + self.Mail = Mail + except ImportError: + _, error_message = check_provider_dependencies('sendgrid') + raise ImportError(error_message) + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via SendGrid API. 
+ + Args: + to: Recipient email address + subject: Email subject + body: Email body + html: True if body is HTML + + Returns: + True if sent successfully + """ + try: + message = self.Mail( + from_email=(self.config.from_address, self.config.from_name), + to_emails=to, + subject=subject, + html_content=body if html else None, + plain_text_content=body if not html else None + ) + + sg = self.SendGridAPIClient(self.config.sendgrid_api_key) + response = sg.send(message) + + if response.status_code in (200, 201, 202): + logging.info(f"[EMAIL] SendGrid email sent to {to}") + return True + else: + logging.error(f"[EMAIL] SendGrid returned status {response.status_code}") + return False + + except Exception as e: + logging.error(f"[EMAIL] SendGrid error: {e}") + raise + + def test_connection(self) -> bool: + """ + Test SendGrid API connection. + + Returns: + True if API key is valid + """ + try: + # SendGrid doesn't have a dedicated test endpoint + # We can verify the API key format and try to initialize the client + sg = self.SendGridAPIClient(self.config.sendgrid_api_key) + + # If we get here without exception, API key format is valid + # Note: This doesn't guarantee the key is active, but it's the best we can do + logging.info("[EMAIL] SendGrid API client initialized successfully") + return True + + except Exception as e: + logging.error(f"[EMAIL] SendGrid connection test failed: {e}") + return False + + +class SESEmailProvider(EmailProvider): + """ + AWS SES email provider implementation. + + Uses boto3 to send emails via Amazon Simple Email Service. 
+ """ + + def __init__(self, config: EmailConfig): + super().__init__(config) + # Import here to avoid dependency if not using SES + try: + import boto3 + from botocore.exceptions import ClientError + self.boto3 = boto3 + self.ClientError = ClientError + except ImportError: + _, error_message = check_provider_dependencies('ses') + raise ImportError(error_message) + + # Initialize SES client + self.ses_client = self.boto3.client( + 'ses', + region_name=self.config.aws_region, + aws_access_key_id=self.config.aws_access_key, + aws_secret_access_key=self.config.aws_secret_key + ) + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via AWS SES. + + Args: + to: Recipient email address + subject: Email subject + body: Email body + html: True if body is HTML + + Returns: + True if sent successfully + """ + try: + if html: + body_part = {'Html': {'Charset': 'UTF-8', 'Data': body}} + else: + body_part = {'Text': {'Charset': 'UTF-8', 'Data': body}} + + response = self.ses_client.send_email( + Source=f"{self.config.from_name} <{self.config.from_address}>", + Destination={'ToAddresses': [to]}, + Message={ + 'Subject': {'Charset': 'UTF-8', 'Data': subject}, + 'Body': body_part + } + ) + + message_id = response.get('MessageId') + logging.info(f"[EMAIL] SES email sent to {to}, MessageId: {message_id}") + return True + + except self.ClientError as e: + error_code = e.response['Error']['Code'] + error_message = e.response['Error']['Message'] + + if error_code == 'MessageRejected': + logging.error(f"[EMAIL] SES rejected message: {error_message}") + elif error_code == 'ConfigurationSetDoesNotExist': + logging.error(f"[EMAIL] SES configuration error: {error_message}") + else: + logging.error(f"[EMAIL] SES error ({error_code}): {error_message}") + + raise + + except Exception as e: + logging.error(f"[EMAIL] SES unexpected error: {e}") + raise + + def test_connection(self) -> bool: + """ + Test AWS SES connection and verify email address. 
+ + Returns: + True if connection successful and sender email verified + """ + try: + # Check if sender email is verified in SES + response = self.ses_client.get_identity_verification_attributes( + Identities=[self.config.from_address] + ) + + attributes = response.get('VerificationAttributes', {}) + status = attributes.get(self.config.from_address, {}).get('VerificationStatus') + + if status == 'Success': + logging.info(f"[EMAIL] SES connection test successful, {self.config.from_address} is verified") + return True + else: + logging.warning( + f"[EMAIL] SES email {self.config.from_address} not verified. " + f"Status: {status}. Emails may fail to send." + ) + return False + + except self.ClientError as e: + logging.error(f"[EMAIL] SES connection test failed: {e}") + return False + + except Exception as e: + logging.error(f"[EMAIL] SES unexpected error: {e}") + return False + + +class GmailOAuthProvider(EmailProvider): + """ + Gmail OAuth email provider implementation. + + Uses Gmail API with OAuth 2.0 authentication instead of SMTP. + Automatically refreshes expired tokens. + """ + + def __init__(self, config: EmailConfig): + """ + Initialize Gmail OAuth provider. 
+ + Args: + config: EmailConfig with OAuth credentials + """ + super().__init__(config) + + # Import Gmail API dependencies + try: + from google.auth.transport.requests import Request + from google.oauth2.credentials import Credentials + from googleapiclient.discovery import build + + self.Request = Request + self.Credentials = Credentials + self.build = build + except ImportError as e: + _, error_message = check_provider_dependencies('gmail-oauth') + raise ImportError(error_message) from e + + self.credentials = self._load_credentials() + self.service = None + + def _load_credentials(self): + """Load OAuth credentials from config.""" + from datetime import datetime, timezone + + if not self.config.oauth_access_token: + raise ValueError("Gmail OAuth access token is required") + + # Parse token expiry + token_expiry = None + if self.config.oauth_token_expiry: + try: + # Parse as timezone-aware datetime + expiry_aware = datetime.fromisoformat( + self.config.oauth_token_expiry.replace('Z', '+00:00') + ) + # Convert to naive UTC datetime (Google's library expects naive datetimes) + token_expiry = expiry_aware.replace(tzinfo=None) + except Exception: + pass + + # Create credentials object + creds = self.Credentials( + token=self.config.oauth_access_token, + refresh_token=self.config.oauth_refresh_token, + token_uri='https://oauth2.googleapis.com/token', + client_id=self.config.oauth_client_id, + client_secret=self.config.oauth_client_secret, + scopes=['https://www.googleapis.com/auth/gmail.send'] + ) + + # Set expiry if available (as naive UTC datetime) + if token_expiry: + creds.expiry = token_expiry + + return creds + + def _refresh_if_expired(self): + """Refresh tokens if expired.""" + if self.credentials.expired or not self.credentials.valid: + logging.info("[EMAIL] Gmail OAuth tokens expired, refreshing...") + self.credentials.refresh(self.Request()) + + # Update config with new tokens + self.config.oauth_access_token = self.credentials.token + if 
self.credentials.refresh_token: + self.config.oauth_refresh_token = self.credentials.refresh_token + if self.credentials.expiry: + self.config.oauth_token_expiry = self.credentials.expiry.isoformat() + + # Mark tokens as updated so caller can persist them + self.config._oauth_tokens_updated = True + + logging.info("[EMAIL] Gmail OAuth tokens refreshed successfully") + + def _get_service(self): + """Get or create Gmail API service.""" + if not self.service: + self._refresh_if_expired() + self.service = self.build('gmail', 'v1', credentials=self.credentials) + return self.service + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via Gmail API. + + Args: + to: Recipient email address + subject: Email subject + body: Email body (HTML or plain text) + html: True if body is HTML + + Returns: + True if sent successfully + """ + try: + import base64 + from email.mime.text import MIMEText + + # Create message + message = MIMEText(body, 'html' if html else 'plain') + message['to'] = to + message['from'] = f"{self.config.from_name} <{self.config.from_address}>" + message['subject'] = subject + + # Encode message + raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode('utf-8') + + # Send via Gmail API + service = self._get_service() + service.users().messages().send( + userId='me', + body={'raw': raw_message} + ).execute() + + logging.info(f"[EMAIL] Gmail OAuth email sent to {to}") + return True + + except Exception as e: + logging.error(f"[EMAIL] Gmail OAuth error: {e}") + raise + + def test_connection(self) -> bool: + """ + Test Gmail OAuth connection. 
+ + Returns: + True if credentials are valid + """ + try: + # Refresh tokens if needed + self._refresh_if_expired() + + # Verify credentials are valid by checking they're not expired + # Note: gmail.send scope doesn't allow reading profile/labels + # so we just verify the credentials object is valid + if self.credentials and self.credentials.valid: + logging.info(f"[EMAIL] Gmail OAuth connection successful: {self.config.from_address}") + return True + else: + logging.error("[EMAIL] Gmail OAuth credentials are invalid") + return False + + except Exception as e: + logging.error(f"[EMAIL] Gmail OAuth connection test failed: {e}") + return False + + +class MicrosoftOAuthProvider(EmailProvider): + """ + Microsoft OAuth email provider implementation. + + Uses Microsoft Graph API with OAuth 2.0 authentication. + Supports Microsoft 365, Outlook.com, and organizational accounts. + """ + + def __init__(self, config: EmailConfig): + """ + Initialize Microsoft OAuth provider. + + Args: + config: EmailConfig with OAuth credentials and tenant_id + """ + super().__init__(config) + + # Import msal dependency + try: + import msal + self.msal = msal + except ImportError as e: + _, error_message = check_provider_dependencies('microsoft-oauth') + raise ImportError(error_message) from e + + if not self.config.oauth_tenant_id: + raise ValueError("Microsoft OAuth requires tenant_id (use 'common' for multi-tenant)") + + # Build MSAL confidential client app + self.app = self._build_msal_app() + self.token_cache = {} + + def _build_msal_app(self): + """Build MSAL confidential client application.""" + authority = f"https://login.microsoftonline.com/{self.config.oauth_tenant_id}" + + return self.msal.ConfidentialClientApplication( + client_id=self.config.oauth_client_id, + client_credential=self.config.oauth_client_secret, + authority=authority + ) + + def _get_access_token(self) -> str: + """ + Get valid access token, refreshing if necessary. 
+ + Returns: + Valid access token + """ + # Check if we have cached token and it's still valid + if self.config.oauth_access_token and not self.config.tokens_need_refresh(): + return self.config.oauth_access_token + + # Need to refresh token + if not self.config.oauth_refresh_token: + raise ValueError("OAuth refresh token is required to refresh access token") + + logging.info("[EMAIL] Microsoft OAuth tokens expired, refreshing...") + + # Acquire token by refresh token + result = self.app.acquire_token_by_refresh_token( + refresh_token=self.config.oauth_refresh_token, + scopes=['https://graph.microsoft.com/Mail.Send'] + ) + + if 'access_token' in result: + # Update config with new tokens + self.config.oauth_access_token = result['access_token'] + + # Update refresh token if new one provided + if 'refresh_token' in result: + self.config.oauth_refresh_token = result['refresh_token'] + + # Calculate and update expiry + if 'expires_in' in result: + from datetime import datetime, timedelta, timezone + expiry = datetime.now(timezone.utc) + timedelta(seconds=result['expires_in']) + self.config.oauth_token_expiry = expiry.isoformat() + + # Mark tokens as updated so caller can persist them + self.config._oauth_tokens_updated = True + + logging.info("[EMAIL] Microsoft OAuth tokens refreshed successfully") + return result['access_token'] + else: + error = result.get('error', 'Unknown error') + error_desc = result.get('error_description', '') + raise Exception(f"Failed to refresh Microsoft OAuth token: {error} - {error_desc}") + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via Microsoft Graph API. 
+ + Args: + to: Recipient email address + subject: Email subject + body: Email body (HTML or plain text) + html: True if body is HTML + + Returns: + True if sent successfully + """ + try: + import requests + + # Get valid access token + access_token = self._get_access_token() + + # Build Graph API request + url = 'https://graph.microsoft.com/v1.0/me/sendMail' + headers = { + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json' + } + + # Construct email message + message = { + 'message': { + 'subject': subject, + 'body': { + 'contentType': 'HTML' if html else 'Text', + 'content': body + }, + 'toRecipients': [ + { + 'emailAddress': { + 'address': to + } + } + ], + 'from': { + 'emailAddress': { + 'name': self.config.from_name, + 'address': self.config.from_address + } + } + }, + 'saveToSentItems': 'true' + } + + # Send request + response = requests.post(url, headers=headers, json=message) + + if response.status_code == 202: + logging.info(f"[EMAIL] Microsoft Graph email sent to {to}") + return True + else: + logging.error(f"[EMAIL] Microsoft Graph returned status {response.status_code}: {response.text}") + return False + + except Exception as e: + logging.error(f"[EMAIL] Microsoft OAuth error: {e}") + raise + + def test_connection(self) -> bool: + """ + Test Microsoft OAuth connection. 
+ + Returns: + True if credentials are valid + """ + try: + import requests + + # Get valid access token + access_token = self._get_access_token() + + # Try to get user profile to verify connection + url = 'https://graph.microsoft.com/v1.0/me' + headers = { + 'Authorization': f'Bearer {access_token}' + } + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + data = response.json() + email = data.get('mail') or data.get('userPrincipalName') + logging.info(f"[EMAIL] Microsoft OAuth connection successful: {email}") + return True + else: + logging.error(f"[EMAIL] Microsoft OAuth connection test failed: {response.status_code}") + return False + + except Exception as e: + logging.error(f"[EMAIL] Microsoft OAuth connection test failed: {e}") + return False + + +class EmailSender: + """ + Main email sender class that routes to appropriate provider. + + Usage: + config = EmailConfig(...) + sender = EmailSender(config) + sender.send(to='user@example.com', subject='Test', body='Hello', html=True) + """ + + def __init__(self, config: EmailConfig): + """ + Initialize email sender with configuration. + + Args: + config: EmailConfig object + + Raises: + ValueError: If provider is unknown or config invalid + """ + self.config = config + + # Check provider compatibility with current installation + dependencies_available, error_message = check_provider_dependencies(config.provider) + if not dependencies_available: + raise ValueError(error_message) + + # Create provider instance + provider_map = { + 'smtp': SMTPEmailProvider, + 'sendgrid': SendGridEmailProvider, + 'ses': SESEmailProvider, + 'gmail-oauth': GmailOAuthProvider, + 'microsoft-oauth': MicrosoftOAuthProvider, + } + + provider_class = provider_map.get(config.provider.lower()) + if not provider_class: + raise ValueError( + f"Unknown email provider: {config.provider}. 
" + f"Supported: {', '.join(provider_map.keys())}" + ) + + self.provider = provider_class(config) + + def send(self, to: str, subject: str, body: str, html: bool = False) -> bool: + """ + Send email via configured provider. + + Args: + to: Recipient email address + subject: Email subject + body: Email body + html: True if body is HTML + + Returns: + True if sent successfully + + Raises: + Exception: If send fails + """ + logging.info(f"[EMAIL] Sending email to {to} via {self.config.provider}") + return self.provider.send(to, subject, body, html) + + def test_connection(self) -> bool: + """ + Test connection to email provider. + + Returns: + True if connection successful + """ + return self.provider.test_connection() + + + +def find_email_config_record(vault: vault_online.VaultOnline, name: str) -> Optional[str]: + """ + Find email config record by name. + + Args: + vault: VaultOnline session + name: Name of email configuration + + Returns: + Record UID if found, None otherwise + """ + for record_uid in vault.vault_data.records: + record = vault.vault_data.load_record(record_uid) + if not isinstance(record, vault_record.TypedRecord): + continue + if record.record_type != 'login': + continue + + # Check if this is an email config by looking for custom field + try: + record_dict = vault_extensions.extract_typed_record_data(record) + custom_fields = record_dict.get('custom', []) + for field in custom_fields: + if field.get('type') == 'text' and field.get('label') == '__email_config__': + if record.title == name: + return record_uid + except: + continue + + return None + + +def load_email_config_from_record(vault: vault_online.VaultOnline, record_uid: str) -> EmailConfig: + """ + Load EmailConfig from a Keeper record. 
+ + Args: + vault: VaultOnline session + record_uid: Record UID + + Returns: + EmailConfig object + + Raises: + CommandError: If record not found or invalid + """ + if record_uid not in vault.vault_data.records: + raise ValueError(f'Record {record_uid} not found') + + record = vault.vault_data.load_record(record_uid) + if not isinstance(record, vault_record.TypedRecord): + raise ValueError(f'Record {record_uid} is not a typed record') + + # Extract record data + record_dict = vault_extensions.extract_typed_record_data(record) + + # Get login/password fields + fields = record_dict.get('fields', []) + login = None + password = None + + for field in fields: + if field.get('type') == 'login': + values = field.get('value', []) + if values: + login = values[0] + elif field.get('type') == 'password': + values = field.get('value', []) + if values: + password = values[0] + + # Get custom fields with provider configuration + custom_fields = record_dict.get('custom', []) + provider_data = {} + + for field in custom_fields: + label = field.get('label', '') + if label.startswith('__email_'): + continue # Skip marker fields + + values = field.get('value', []) + if values: + provider_data[label] = values[0] + + # Build EmailConfig + provider = provider_data.get('provider', 'smtp') + + config = EmailConfig( + record_uid=record_uid, + name=record.title, + provider=provider, + from_address=provider_data.get('from_address', ''), + from_name=provider_data.get('from_name', 'Keeper Commander') + ) + + # Provider-specific fields + if provider == 'smtp': + config.smtp_host = provider_data.get('smtp_host') + config.smtp_port = int(provider_data.get('smtp_port', 587)) + config.smtp_username = login or provider_data.get('smtp_username') + config.smtp_password = password or provider_data.get('smtp_password') + config.smtp_use_tls = provider_data.get('smtp_use_tls', 'true').lower() == 'true' + config.smtp_use_ssl = provider_data.get('smtp_use_ssl', 'false').lower() == 'true' + + elif provider 
== 'ses': + config.aws_region = provider_data.get('aws_region') + config.aws_access_key = login or provider_data.get('aws_access_key') + config.aws_secret_key = password or provider_data.get('aws_secret_key') + + elif provider == 'sendgrid': + config.sendgrid_api_key = password or provider_data.get('sendgrid_api_key') + + elif provider in ('gmail-oauth', 'microsoft-oauth'): + # OAuth tokens stored in login/password fields + config.oauth_access_token = login + config.oauth_refresh_token = password + + # OAuth configuration from custom fields + config.oauth_client_id = provider_data.get('oauth_client_id') + config.oauth_client_secret = provider_data.get('oauth_client_secret') + config.oauth_token_expiry = provider_data.get('oauth_token_expiry') + + if provider == 'microsoft-oauth': + config.oauth_tenant_id = provider_data.get('oauth_tenant_id', 'common') + + return config + + +def build_onboarding_email( + share_url: str, + custom_message: str, + record_title: str, + expiration: Optional[str] = None +) -> str: + """ + Build HTML email for onboarding with one-time share link. + + Args: + share_url: One-time share URL + custom_message: Custom message from administrator + record_title: Title of the record being shared + expiration: Human-readable expiration time (e.g., "24 hours", "1 day") + + Returns: + HTML email body + + Raises: + FileNotFoundError: If email template not found + """ + # Load template + template = load_email_template('onboarding.html') + + # Prepare variables + expiration_text = ( + f"This link will expire in {expiration}" + if expiration + else "This link will expire after first use" + ) + + # Fill in template + html = template.format( + custom_message=custom_message, + share_url=share_url, + record_title=record_title, + expiration_text=expiration_text + ) + + return html + + +def load_email_template(template_name: str = 'onboarding.html') -> str: + """ + Load email template from resources directory. 
+ + Args: + template_name: Name of template file (default: onboarding.html) + + Returns: + Template content as string + + Raises: + FileNotFoundError: If template file doesn't exist + """ + # Get path to template file + module_dir = os.path.dirname(os.path.abspath(__file__)) + template_path = os.path.join(module_dir, 'resources', 'email_templates', template_name) + + if not os.path.exists(template_path): + raise FileNotFoundError(f"Email template not found: {template_path}") + + with open(template_path, 'r', encoding='utf-8') as f: + return f.read() + + +def validate_email_provider_dependencies(provider: str) -> tuple[bool, Optional[str]]: + """ + Validate that required dependencies for an email provider are installed. + + This function checks if dependencies are available WITHOUT creating the provider instance, + allowing early validation before performing operations like password rotation. + + Args: + provider: Email provider name ('smtp', 'sendgrid', 'ses', 'gmail-oauth', 'microsoft-oauth') + + Returns: + Tuple of (is_valid, error_message): + - is_valid: True if dependencies are available, False otherwise + - error_message: None if valid, otherwise contains install instructions + + Examples: + >>> valid, error = validate_email_provider_dependencies('gmail-oauth') + >>> if not valid: + ... print(error) + Gmail OAuth requires google-api-python-client and related libraries. 
+ Install with: pip install keepercommander[email-gmail-oauth] + """ + provider = provider.lower() + + # SMTP uses Python built-ins, no extra dependencies needed + if provider == 'smtp': + return True, None + + # SendGrid + if provider == 'sendgrid': + try: + import sendgrid # noqa: F401 + return True, None + except ImportError: + return False, ( + "SendGrid email provider requires the 'sendgrid' library.\n" + "Install with: pip install keepercommander[email-sendgrid]\n" + "Or install manually: pip install sendgrid>=6.10.0" + ) + + # AWS SES + if provider == 'ses': + try: + import boto3 # noqa: F401 + return True, None + except ImportError: + return False, ( + "AWS SES email provider requires the 'boto3' library.\n" + "Install with: pip install keepercommander[email-ses]\n" + "Or install manually: pip install boto3>=1.26.0" + ) + + # Gmail OAuth + if provider == 'gmail-oauth': + try: + import google.auth # noqa: F401 + import google.auth.transport.requests # noqa: F401 + import google.oauth2.credentials # noqa: F401 + import googleapiclient.discovery # noqa: F401 + return True, None + except ImportError: + return False, ( + "Gmail OAuth email provider requires Google API libraries.\n" + "Install with: pip install keepercommander[email-gmail-oauth]\n" + "Or install manually: pip install google-api-python-client google-auth google-auth-oauthlib google-auth-httplib2" + ) + + # Microsoft OAuth + if provider == 'microsoft-oauth': + try: + import msal # noqa: F401 + return True, None + except ImportError: + return False, ( + "Microsoft OAuth email provider requires the 'msal' library.\n" + "Install with: pip install keepercommander[email-microsoft-oauth]\n" + "Or install manually: pip install msal>=1.20.0" + ) + + # Unknown provider + return False, ( + f"Unknown email provider: {provider}\n" + f"Supported providers: smtp, sendgrid, ses, gmail-oauth, microsoft-oauth" + ) + + +def update_oauth_tokens_in_record(vault: vault_online.VaultOnline, record_uid: str, + 
access_token: str, refresh_token: str, + token_expiry: str) -> None: + """ + Update OAuth tokens in email config record after automatic refresh. + + This function is called by OAuth email providers (GmailOAuthProvider, + MicrosoftOAuthProvider) after they automatically refresh expired tokens. + It updates the Keeper record to persist the new tokens. + + Args: + vault: VaultOnline session + record_uid: UID of the email config record to update + access_token: New OAuth access token + refresh_token: New OAuth refresh token (may be same as old) + token_expiry: New token expiry in ISO 8601 format + + Raises: + CommandError: If record not found or update fails + """ + # Load the record + if record_uid not in vault.vault_data.records: + vault.sync_down(force=True) + + if record_uid not in vault.vault_data.records: + raise ValueError(f'Email configuration record not found: {record_uid}') + + # Load as TypedRecord + record = vault.vault_data.load_record(record_uid) + if not isinstance(record, vault_record.TypedRecord): + raise ValueError(f'Record is not a TypedRecord: {record_uid}') + + # Update token fields (login = access_token, password = refresh_token) + for field in record.fields: + if field.type == 'login': + field.value = [access_token] + elif field.type == 'password': + field.value = [refresh_token] + + # Update token expiry in custom fields + expiry_field_found = False + for field in record.custom: + if field.label == 'oauth_token_expiry': + field.value = [token_expiry] + expiry_field_found = True + break + + # Add expiry field if it doesn't exist + if not expiry_field_found: + record.custom.append(vault.TypedField.new_field('text', token_expiry, 'oauth_token_expiry')) + + # Update the record + record_management.update_record(vault, record) + + # Sync changes + vault.sync_down(force=True) + + logging.debug(f'[EMAIL-CONFIG] Updated OAuth tokens for record: {record_uid}') + diff --git a/keepercli-package/src/keepercli/helpers/gateway_utils.py 
b/keepercli-package/src/keepercli/helpers/gateway_utils.py index be6194a9..4531bd99 100644 --- a/keepercli-package/src/keepercli/helpers/gateway_utils.py +++ b/keepercli-package/src/keepercli/helpers/gateway_utils.py @@ -1,8 +1,9 @@ from time import time from typing import List +from keepersdk import utils from keepersdk.vault import vault_online -from keepersdk.proto import pam_pb2 +from keepersdk.proto import pam_pb2, enterprise_pb2 from keepersdk.vault import ksm_management @@ -69,7 +70,8 @@ def create_gateway( first_access_expire_duration_ms=first_access_expire_duration_ms, access_expire_in_ms=None, master_key=master_key, - server=vault.keeper_auth.keeper_endpoint.server + server=vault.keeper_auth.keeper_endpoint.server, + client_type=enterprise_pb2.DISCOVERY_AND_ROTATION_CONTROLLER ) return _extract_one_time_token(one_time_token_dict) @@ -109,3 +111,13 @@ def set_gateway_max_instances(vault: vault_online.VaultOnline, gateway_uid: byte rest_endpoint=REST_ENDPOINT_SET_MAX_INSTANCE_COUNT, request=rq ) + + +def find_connected_gateways(all_controllers, identifier): + + found_connected_controller_uid_bytes = next((c for c in all_controllers if (utils.base64_url_encode(c) == identifier)), None) + + if found_connected_controller_uid_bytes: + return found_connected_controller_uid_bytes + else: + return None diff --git a/keepercli-package/src/keepercli/helpers/record_utils.py b/keepercli-package/src/keepercli/helpers/record_utils.py index 614129c9..b5868729 100644 --- a/keepercli-package/src/keepercli/helpers/record_utils.py +++ b/keepercli-package/src/keepercli/helpers/record_utils.py @@ -12,6 +12,8 @@ from keepersdk import crypto, utils from keepersdk.proto.APIRequest_pb2 import AddExternalShareRequest, Device from keepersdk.proto.enterprise_pb2 import GetSharingAdminsRequest, GetSharingAdminsResponse +from keepersdk.proto.router_pb2 import RouterRotationInfo +from keepersdk.proto.pam_pb2 import PAMGenericUidRequest from keepersdk.vault import vault_online, 
vault_record, vault_types, vault_utils from ..commands.base import CommandError @@ -189,3 +191,13 @@ def resolve_record(context: KeeperParams, name: str) -> str: return uid if record_uid is None: raise CommandError(f'Record not found: {name}') + + +def record_rotation_get(vault: vault_online.VaultOnline, record_uid_bytes: bytes) -> RouterRotationInfo: + + rq = PAMGenericUidRequest() + rq.uid = record_uid_bytes + + rotation_info_rs = vault.keeper_auth.execute_auth_rest(rest_endpoint='pam/get_rotation_info', request=rq, response_type=RouterRotationInfo) + + return rotation_info_rs diff --git a/keepercli-package/src/keepercli/helpers/router_utils.py b/keepercli-package/src/keepercli/helpers/router_utils.py index 873e78ce..9dfb1614 100644 --- a/keepercli-package/src/keepercli/helpers/router_utils.py +++ b/keepercli-package/src/keepercli/helpers/router_utils.py @@ -1,12 +1,23 @@ +from datetime import datetime +import json import logging +import os import google +import requests from typing import Optional -from keepersdk.proto import pam_pb2 +from keepersdk import utils, crypto +from keepersdk.authentication import endpoint +from keepersdk.proto import pam_pb2, router_pb2 from keepersdk.vault import vault_online +from keepersdk.errors import KeeperApiError +from ..helpers import gateway_utils +from ..commands.pam.pam_dto import GatewayAction +from ..params import KeeperParams API_PATH_GET_CONTROLLERS = "get_controllers" +VERIFY_SSL = bool(os.environ.get("VERIFY_SSL", "TRUE") == "TRUE") def router_get_connected_gateways(vault: vault_online.VaultOnline) -> Optional[pam_pb2.PAMOnlineControllers]: @@ -27,3 +38,461 @@ def router_get_connected_gateways(vault: vault_online.VaultOnline) -> Optional[p return pam_online_controllers return None + + +def router_send_action_to_gateway(context: KeeperParams, gateway_action: GatewayAction, message_type, is_streaming, + destination_gateway_uid_str=None, gateway_timeout=15000, transmission_key=None, + encrypted_transmission_key=None, 
encrypted_session_token=None): + # Default time out how long the response from the Gateway should be + krouter_host = context.auth.keeper_endpoint.get_router_server() + + # 1. Find connected gateway to send action to + try: + router_enterprise_controllers_connected = \ + [x.controllerUid for x in router_get_connected_gateways(context.vault).controllers] + + except requests.exceptions.ConnectionError as errc: + logging.info(f"Looks like router is down. Router URL [{krouter_host}]") + return + except Exception as e: + raise e + + if destination_gateway_uid_str: + # Means that we want to get info for a specific Gateway + + destination_gateway_uid_bytes = utils.base64_url_decode(destination_gateway_uid_str) + + if destination_gateway_uid_bytes not in router_enterprise_controllers_connected: + logging.warning(f"\tThis Gateway currently is not online.") + return + else: + if not router_enterprise_controllers_connected or len(router_enterprise_controllers_connected) == 0: + logging.warning(f"\tNo running or connected Gateways in your enterprise. " + f"Start the Gateway before sending any action to it.") + return + elif len(router_enterprise_controllers_connected) == 1: + destination_gateway_uid_bytes = router_enterprise_controllers_connected[0] + destination_gateway_uid_str = utils.base64_url_encode(destination_gateway_uid_bytes) + else: # There are more than two Gateways connected. Selecting the right one + + if not gateway_action.gateway_destination: + logging.warning(f"There are more than one Gateways running in your enterprise. " + f"Only 'pam action rotate' is able to know " + f"which Gateway should receive a request. Any other commands should have a Gateway specified. " + f"See help for the command you are trying to use. 
To find connected gateways run action " + f"'pam gateway list' and provide Gateway UID or Gateway Name.") + + return + + destination_gateway_uid_bytes = gateway_utils.find_connected_gateways(router_enterprise_controllers_connected, gateway_action.gateway_destination) + destination_gateway_uid_str = utils.base64_url_encode(destination_gateway_uid_bytes) + + msg_id = gateway_action.conversationId if gateway_action.conversationId else GatewayAction.generate_conversation_id('true') + + rq = router_pb2.RouterControllerMessage() + rq.messageUid = utils.base64_url_decode(msg_id) if isinstance(msg_id, str) else msg_id + rq.controllerUid = destination_gateway_uid_bytes + rq.messageType = message_type + rq.streamResponse = is_streaming + rq.payload = gateway_action.toJSON().encode('utf-8') + rq.timeout = gateway_timeout + + if not transmission_key: + transmission_key = utils.generate_aes_key() + + response = router_send_message_to_gateway( + context=context, + transmission_key=transmission_key, + rq_proto=rq, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token) + + rs_body = response.content + + if type(rs_body) == bytes: + router_response = router_pb2.RouterResponse() + router_response.ParseFromString(rs_body) + + rrc = router_pb2.RouterResponseCode.Name(router_response.responseCode) + if router_response.responseCode == router_pb2.RRC_OK: + logging.debug("Good response...") + + elif router_response.responseCode == router_pb2.RRC_BAD_STATE: + raise Exception(router_response.errorMessage + ' response code: ' + rrc) + + elif router_response.responseCode == router_pb2.RRC_TIMEOUT: + # Router tried to send message to the Controller but the response didn't arrive on time + # ex. 
if Router is expecting response to be within 3 sec, but the gateway didn't respond within that time + raise Exception(router_response.errorMessage + ' response code: ' + rrc) + + elif router_response.responseCode == router_pb2.RRC_CONTROLLER_DOWN: + # Sent an action to the Controller that is no longer online + raise Exception(router_response.errorMessage + ' response code: ' + rrc) + + else: + raise Exception(router_response.errorMessage + ' response code: ' + rrc) + + + payload_encrypted = router_response.encryptedPayload + if payload_encrypted: + + payload_decrypted = crypto.decrypt_aes_v2(payload_encrypted, transmission_key) + + controller_response = pam_pb2.ControllerResponse() + controller_response.ParseFromString(payload_decrypted) + + gateway_response_payload = json.loads(controller_response.payload) + else: + gateway_response_payload = {} + + return { + 'response': gateway_response_payload + } + + +def router_send_message_to_gateway(context: KeeperParams, transmission_key, rq_proto, + encrypted_transmission_key=None, encrypted_session_token=None): + + krouter_host = context.auth.keeper_endpoint.get_router_server() + + if not encrypted_transmission_key: + server_public_key = endpoint.SERVER_PUBLIC_KEYS[context.auth.keeper_endpoint.server_key_id] + + if context.auth.keeper_endpoint.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) + + encrypted_payload = b'' + + if rq_proto: + encrypted_payload = crypto.encrypt_aes_v2(rq_proto.SerializeToString(), transmission_key) + if not encrypted_session_token: + encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(context.auth.auth_context.session_token), transmission_key) + + rs = requests.post( + krouter_host+"/api/user/send_controller_message", + verify=VERIFY_SSL, + + headers={ + 'TransmissionKey': utils.base64_url_encode(encrypted_transmission_key), 
+ 'Authorization': f'KeeperUser {utils.base64_url_encode(encrypted_session_token)}', + }, + data=encrypted_payload if rq_proto else None + ) + + if rs.status_code >= 300: + raise Exception(str(rs.status_code) + ': error: ' + rs.reason + ', message: ' + rs.text) + + return rs + + +def print_router_response(router_response, response_type, original_conversation_id=None, is_verbose=False, gateway_uid=''): + if not router_response: + return + + router_response_response = router_response.get('response') + router_response_response_payload_str = router_response_response.get('payload') + router_response_response_payload_dict = json.loads(router_response_response_payload_str) + + if router_response_response_payload_dict.get('warnings'): + for w in router_response_response_payload_dict.get('warnings'): + if w: + logging.warning(f'{w}') + + if original_conversation_id: + # gateway_response_conversation_id = utils.base64_url_decode(router_response_response_payload_dict.get('conversation_id')).decode("utf-8") + # IDs are either bytes or base64 encoded strings which may be padded + gateway_response_conversation_id = router_response_response_payload_dict.get('conversation_id', None) + oid = (utils.base64_url_decode(original_conversation_id) + if isinstance(original_conversation_id, str) + else original_conversation_id) + gid = (utils.base64_url_decode(gateway_response_conversation_id) + if isinstance(gateway_response_conversation_id, str) + else gateway_response_conversation_id) + + if oid != gid: + logging.error(f"Message ID that was sent to the server [{original_conversation_id}] and the conversation id " + f"received back [{gateway_response_conversation_id}] are different. 
That probably means that " + f"the gateway sent a wrong response that was not associated with the request.") + + if not (router_response_response_payload_dict.get('is_ok') or router_response_response_payload_dict.get('isOk')): + logging.error(f"{json.dumps(router_response_response_payload_dict, indent=4)}") + return + + if router_response_response_payload_dict.get('isScheduled') or router_response_response_payload_dict.get('is_scheduled'): + conversation_id = router_response_response_payload_dict.get('conversation_id') + + gwinfo = f" --gateway={gateway_uid}" if gateway_uid else "" + logging.info(f"Scheduled action id: {conversation_id}") + logging.info(f"The action has been scheduled, use command 'pam action job-info {conversation_id}{gwinfo}' to get status of the scheduled action") + return + + elif response_type == 'job_info': + job_info = router_response_response_payload_dict.get('data') + exec_response_value = job_info.get('execResponseValue') + exec_response_value_msg = exec_response_value.get('message') if exec_response_value else None + exec_response_value_logs = exec_response_value.get('execLog') if exec_response_value else None + exec_duration = job_info.get('executionDuration') + exec_status = job_info.get('status') + exec_exception = job_info.get('execException') + + logging.info(f'Execution Details\n-------------------------') + + logging.info(f'\tStatus : {job_info.get("reason") if job_info.get("reason") else exec_status}') + + if exec_duration: + logging.info(f'\tDuration : {exec_duration}') + + if exec_response_value_msg: + logging.info(f'\tResponse Message : {exec_response_value_msg}') + + if exec_response_value_logs: + logging.info(f'\tPost-execution scripts logs:') + for el in exec_response_value_logs: + logging.info(f'\t\tscript: {el.get("name")}') + logging.info(f'\t\treturn code: {el.get("return_code")}') + if el.get("stdout"): + logging.info(f'\t\tstdout:\n---\n{el.get("stdout")}\n---') + if el.get("stderr"): + 
logging.info(f'\t\tstderr:\n---\n{el.get("stderr")}\n---') + logging.info(f'\n') + + if exec_exception: + logging.info(f'\tExecution Exception : {exec_exception}') + + elif response_type == 'gateway_info': + + gateway_info = router_response_response_payload_dict.get('data') + + # Version and Gateway Details + logging.info(f'\nGateway Details') + gateway_config = gateway_info.get('gateway-config', {}) + version_info = gateway_config.get('version', {}) + if version_info.get("current"): + logging.info(f'\tVersion : {version_info.get("current")}') + + # Convert Unix timestamp to readable format + started_time = gateway_config.get("connection_info", {}).get("started") + try: + if started_time: + started_dt = datetime.fromtimestamp(float(started_time)) + local_tz = datetime.now().astimezone().tzinfo + started_str = f"{started_dt.strftime('%Y-%m-%d %H:%M:%S')} {local_tz}" + logging.info(f'\tStarted Time : {started_str}') + except (ValueError, TypeError): + pass + + if gateway_config.get("ws_log_file"): + logging.info(f'\tLogs Location : {gateway_config.get("ws_log_file")}') + + # Environment Info + machine_env = gateway_info.get('machine', {}).get('environment', {}) + if machine_env and machine_env.get('provider'): + logging.info(f'\nEnvironment Details') + logging.info(f'\tProvider : {machine_env.get("provider")}') + if machine_env.get('provider') != 'Local/Other': + if machine_env.get('account_id'): + logging.info(f'\tAccount : {machine_env.get("account_id")}') + if machine_env.get('region'): + logging.info(f'\tRegion : {machine_env.get("region")}') + if machine_env.get('instance_type'): + logging.info(f'\tInstance Type : {machine_env.get("instance_type")}') + + # Machine Details + machine = gateway_info.get('machine', {}) + logging.info(f'\nMachine Details') + + if machine.get("hostname"): + logging.info(f'\tHostname : {machine.get("hostname")}') + if machine.get("ip_address_local") and machine.get("ip_address_local") != "unknown": + logging.info(f'\tIP (Local) : 
{machine.get("ip_address_local")}') + if machine.get("ip_address_external"): + logging.info(f'\tIP (External) : {machine.get("ip_address_external")}') + + os_info = [] + if machine.get("system"): os_info.append(machine.get("system")) + if machine.get("release"): os_info.append(machine.get("release")) + if os_info: + logging.info(f'\tOperating System : {" ".join(os_info)}') + + memory = machine.get('memory', {}) + if memory.get('free_gb') is not None and memory.get('total_gb') is not None: + logging.info(f'\tMemory : {memory.get("free_gb")}GB free / {memory.get("total_gb")}GB total') + + # Core Package Versions - Extract from installed packages + installed_packages = { + pkg.split('==')[0]: pkg.split('==')[1] + for pkg in machine.get('installed-python-packages', []) + } + + core_packages = [ + ('KDNRM', installed_packages.get('kdnrm')), + ('Keeper GraphSync', installed_packages.get('keeper-dag')), + ('Discovery Common', installed_packages.get('discovery-common')) + ] + + # Only print Core Components section if at least one core package is found + if any(version for _, version in core_packages): + logging.info(f'\nCore Components') + for name, version in core_packages: + if version: # Only print if version is found + logging.info(f'\t{name:<16} : {version}') + + # KSM Details + logging.info(f'\nKSM Application Details') + ksm_app = gateway_info.get('ksm', {}).get('app', {}) + + if ksm_app.get("title"): + logging.info(f'\tTitle : {ksm_app.get("title")}') + if ksm_app.get("records-count") is not None: + logging.info(f'\tRecords Count : {ksm_app.get("records-count")}') + if ksm_app.get("folders-count") is not None: + logging.info(f'\tFolders Count : {ksm_app.get("folders-count")}') + if ksm_app.get("expires-on"): + logging.info(f'\tExpires On : {ksm_app.get("expires-on")}') + logging.info(f'\tWarnings : {ksm_app.get("warnings") or "None"}') + + # Router Details + logging.info(f'\nRouter Connection') + router_conn = gateway_info.get('router', {}).get('connection', {}) + 
if router_conn.get("base-url"): + logging.info(f'\tURL : {router_conn.get("base-url")}') + router_status = router_conn.get("status", "UNKNOWN").lower() + logging.info(f'\tStatus : {router_status}') + + # PAM Configurations + logging.info(f'\nPAM Configurations Accessible to this Gateway') + pam_configs = gateway_info.get('pam_configurations', []) + if pam_configs: + for idx, config in enumerate(pam_configs, 1): + logging.info(f'\t{idx}. {config}') + else: + logging.info(f'\tNo PAM Configurations found') + + # Additional details for verbose mode + if is_verbose: + logging.info(f'\nAdditional Details') + if machine.get("working-dir"): + logging.info(f'\tWorking Directory : {machine.get("working-dir")}') + if machine.get("package-dir"): + logging.info(f'\tPackage Directory: {machine.get("package-dir")}') + if machine.get("executable"): + logging.info(f'\tPython Executable: {machine.get("executable")}') + + if machine.get('installed-python-packages'): + logging.info(f'\nInstalled Python Packages') + for package in sorted(machine.get('installed-python-packages', [])): + logging.info(f'\t{package}') + + +def get_response_payload(router_response): + + router_response_response = router_response.get('response') + router_response_response_payload_str = router_response_response.get('payload') + router_response_response_payload_dict = json.loads(router_response_response_payload_str) + + return router_response_response_payload_dict + + +def _post_request_to_router(context: KeeperParams, path, rq_proto=None, rs_type=None, method='post', + raw_without_status_check_response=False, query_params=None, transmission_key=None, + encrypted_transmission_key=None, encrypted_session_token=None): + krouter_host = context.auth.keeper_endpoint.get_router_server() + path = '/api/user/' + path + + if not transmission_key: + transmission_key = utils.generate_aes_key() + if not encrypted_transmission_key: + server_public_key = 
endpoint.SERVER_PUBLIC_KEYS[context.auth.keeper_endpoint.server_key_id] + + if context.auth.keeper_endpoint.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) + + encrypted_payload = b'' + + if rq_proto: + if logging.getLogger().level <= logging.DEBUG: + js = google.protobuf.json_format.MessageToJson(rq_proto) + logging.debug('>>> [GW RQ] %s: %s', path, js) + encrypted_payload = crypto.encrypt_aes_v2(rq_proto.SerializeToString(), transmission_key) + + if not encrypted_session_token: + encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(context.auth.auth_context.session_token), transmission_key) + + try: + rs = requests.request(method, + krouter_host + path, + params=query_params, + verify=VERIFY_SSL, + headers={ + 'TransmissionKey': utils.base64_url_encode(encrypted_transmission_key), + 'Authorization': f'KeeperUser {utils.base64_url_encode(encrypted_session_token)}' + }, + data=encrypted_payload if rq_proto else None + ) + except ConnectionError as e: + raise KeeperApiError(-1, f"KRouter is not reachable on '{krouter_host}'. 
Error: {e}")
+    except Exception as ex:
+        raise ex
+
+    content_type = rs.headers.get('Content-Type') or ''
+
+    if raw_without_status_check_response:
+        return rs
+
+    if rs.status_code < 400:
+        if content_type == 'application/json':
+            return rs.json()
+
+        rs_body = rs.content
+        if isinstance(rs_body, bytes):
+            router_response = router_pb2.RouterResponse()
+            router_response.ParseFromString(rs_body)
+
+            rrc = router_pb2.RouterResponseCode.Name(router_response.responseCode)
+            if router_response.responseCode != router_pb2.RRC_OK:
+                raise Exception(router_response.errorMessage + ' Response code: ' + rrc)
+
+            if router_response.encryptedPayload:
+                payload_encrypted = router_response.encryptedPayload
+                payload_decrypted = crypto.decrypt_aes_v2(payload_encrypted, transmission_key)
+            else:
+                payload_decrypted = None
+
+            if rs_type:
+                if payload_decrypted:
+                    rs_proto = rs_type()
+                    rs_proto.ParseFromString(payload_decrypted)
+                    if logging.getLogger().level <= logging.DEBUG:
+                        js = google.protobuf.json_format.MessageToJson(rs_proto)
+                        logging.debug('>>> [GW RS] %s: %s', 'get_rotation_schedules', js)
+                    return rs_proto
+                else:
+                    return None
+
+            return payload_decrypted
+
+        return rs_body
+    else:
+        raise KeeperApiError(rs.status_code, rs.text)
+
+
+def router_set_record_rotation_information(context: KeeperParams, proto_request, transmission_key=None,
+                                           encrypted_transmission_key=None, encrypted_session_token=None):
+    rs = _post_request_to_router(context, 'set_record_rotation', proto_request, transmission_key=transmission_key,
+                                 encrypted_transmission_key=encrypted_transmission_key,
+                                 encrypted_session_token=encrypted_session_token)
+
+    return rs
+
+
+def router_configure_resource(context: KeeperParams, proto_request, transmission_key=None,
+                              encrypted_transmission_key=None, encrypted_session_token=None):
+    rs = _post_request_to_router(context, 'configure_resource', proto_request, transmission_key=transmission_key,
+                                 encrypted_transmission_key=encrypted_transmission_key,
+
encrypted_session_token=encrypted_session_token) + + return rs diff --git a/keepersdk-package/requirements.txt b/keepersdk-package/requirements.txt index c767bfdf..cba5d495 100644 --- a/keepersdk-package/requirements.txt +++ b/keepersdk-package/requirements.txt @@ -5,3 +5,4 @@ protobuf>=5.28.3 websockets>=13.1 fido2>=2.0.0; python_version>='3.10' email-validator>=2.0.0 +pydantic>=2.6.4; python_version>='3.8' diff --git a/keepersdk-package/src/keepersdk/authentication/yubikey.py b/keepersdk-package/src/keepersdk/authentication/yubikey.py index 83218539..66d3775e 100644 --- a/keepersdk-package/src/keepersdk/authentication/yubikey.py +++ b/keepersdk-package/src/keepersdk/authentication/yubikey.py @@ -1,6 +1,7 @@ import abc import getpass import json +import logging import os import threading from typing import Optional, Any, Dict @@ -11,6 +12,13 @@ from fido2.webauthn import PublicKeyCredentialRequestOptions, UserVerificationRequirement, AuthenticationResponse, PublicKeyCredentialCreationOptions from fido2.ctap2 import Ctap2, ClientPin from .. import utils +from prompt_toolkit import PromptSession + + +prompt_session = None +if os.isatty(0) and os.isatty(1): + prompt_session = PromptSession(multiline=False, complete_while_typing=False) + class IKeeperUserInteraction(abc.ABC): diff --git a/keepersdk-package/src/keepersdk/helpers/config_utils.py b/keepersdk-package/src/keepersdk/helpers/config_utils.py new file mode 100644 index 00000000..87940da8 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/config_utils.py @@ -0,0 +1,37 @@ +from ..proto import pam_pb2 +from ..vault import vault_extensions, vault_online, vault_record +from .. 
import utils, crypto + +def pam_configuration_create_record_v6(vault: vault_online.VaultOnline, record: vault_record.TypedRecord, folder_uid: str): + if not record.record_uid: + record.record_uid = utils.generate_uid() + + if not record.record_key: + record.record_key = utils.generate_aes_key() + + record_data = vault_extensions.extract_typed_record_data(record) + json_data = record.load_record_data(record_data) + + car = pam_pb2.ConfigurationAddRequest() + car.configurationUid = utils.base64_url_decode(record.record_uid) + car.recordKey = crypto.encrypt_aes_v2(record.record_key, vault.keeper_auth.auth_context.data_key) + car.data = crypto.encrypt_aes_v2(json_data, record.record_key) + + vault.keeper_auth.execute_auth_rest('pam/add_configuration_record', car) + + +def configuration_controller_get(vault: vault_online.VaultOnline, config_uid_bytes: bytes): + """ + Get the Controller UID that has access to the configuration UID + Retrieves a keeper.pam_controller record, from given configuration_uid provided in request. + controller_uid is the UID of the user who has access to the configuration url_safe_str_to_bytes(config_uid) + """ + rq = pam_pb2.PAMGenericUidRequest() + rq.uid = config_uid_bytes + + config_info_rs = vault.keeper_auth.execute_auth_rest('pam/get_configuration_controller', rq, response_type=pam_pb2.PAMController) + + if config_info_rs: + return config_info_rs + else: + return None diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/__init__.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/__init__.py new file mode 100644 index 00000000..a65e1fa0 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/__init__.py @@ -0,0 +1,307 @@ + +import csv +from enum import Enum +import logging +import sys +import time +import os +from pydantic import BaseModel +from typing import Any, Dict, Optional, Tuple, Union + +from . 
import dag_utils, dag_crypto, exceptions +from .dag_types import SyncQuery +from .__version__ import __version__ as dag_version + +from ...proto import router_pb2, GraphSync_pb2 + +class ConnectionBase: + + ADD_DATA = "/add_data" + SYNC = "/sync" + + TIMEOUT = 30 + + def __init__(self, + is_device: bool, + logger: Optional[logging.Logger] = None, + log_transactions: Optional[bool] = None, + log_transactions_dir: Optional[str] = None, + use_read_protobuf: bool = False, + use_write_protobuf: bool = False): + + # device is a gateway device if is_device is False then we use user authentication flow + self.is_device = is_device + + if logger is None: + logger = logging.getLogger() + self.logger = logger + + # Debug tool; log transaction to the file + if log_transactions is not None: + self.log_transactions = dag_utils.value_to_boolean(log_transactions) + else: + self.log_transactions: bool = dag_utils.value_to_boolean(os.environ.get("GS_LOG_TRANS", False)) + + self.log_transactions_dir = os.environ.get("GS_LOG_TRANS_DIR", log_transactions_dir) + if self.log_transactions_dir is None: + self.log_transactions_dir = "." 
+ + if self.log_transactions is True: + self.logger.info("keeper-dag transaction logging is ENABLED; " + f"write directory at {self.log_transactions_dir}") + + self.use_read_protobuf = use_read_protobuf + self.use_write_protobuf = use_write_protobuf + + # This should stay none for KSM + self.transmission_key = None + + def close(self): + if hasattr(self, "logger"): + self.logger = None + del self.logger + + def __del__(self): + self.close() + + def log_transaction_path(self, file: str): + return os.path.join(self.log_transactions_dir, f"graph_{file}.csv") + + @staticmethod + def get_record_uid(record: object) -> str: + pass + + @staticmethod + def get_key_bytes(record: object) -> bytes: + pass + + @staticmethod + def get_encrypted_payload_data(encrypted_payload_data: bytes) -> bytes: + try: + router_response = router_pb2.RouterResponse() + router_response.ParseFromString(encrypted_payload_data) + return router_response.encryptedPayload + except Exception as err: + raise Exception(f"Could not parse router response: {err}") + + def rest_call_to_router(self, + http_method: str, + endpoint: str, + agent: str, + payload: Optional[Union[str, bytes]] = None, + retry: int = 5, + retry_wait: float = 10, + throttle_inc_factor: float = 1.5, + timeout: Optional[int] = None, + headers: Optional[Dict] = None) -> Optional[bytes]: + return b"" + + def _endpoint(self, action: str, endpoint: Optional[str] = None) -> str: + + """ + Build the endpoint on the remote site. + + This method will attempt to fix slashes. 
+
+        :param action:
+        :param endpoint:
+        :return:
+        """
+
+        # Make sure endpoint is /path/to/endpoint; starting / and no ending /
+        if endpoint is not None and endpoint != "":
+            if isinstance(endpoint, Enum):
+                endpoint = endpoint.value
+
+            while endpoint.startswith("/"):
+                endpoint = endpoint[1:]
+            while endpoint.endswith("/"):
+                endpoint = endpoint[:-1]
+            endpoint = "/" + endpoint
+        else:
+            endpoint = ""
+
+        while action.startswith("/"):
+            action = action[1:]
+        while action.endswith("/"):
+            action = action[:-1]
+        action = "/" + action
+
+        base = "/api/device"
+        if not self.is_device:
+            base = "/api/user"
+
+        return base + endpoint + action
+
+    def write_transaction_log(self,
+                              agent: str,
+                              endpoint: str,
+                              graph_id: Optional[int] = None,
+                              request: Optional[Any] = None,
+                              response: Optional[Any] = None,
+                              error: Optional[str] = None):
+        # If log transaction is True, we want to append to the log file.
+
+        if self.log_transactions is True:
+
+            file_name = graph_id
+            if file_name is None:
+                file_name = endpoint.replace("/", "_")
+
+            timestamp = time.time()
+
+            if isinstance(request, BaseModel):
+                request = request.model_dump_json()
+            elif hasattr(request, "SerializeToString"):
+                request = request.SerializeToString()
+
+            # BUGFIX(review): serialize the response object itself (previously re-serialized
+            # `request`, so the request payload was logged in the response column).
+            if isinstance(response, BaseModel):
+                response = response.model_dump_json()
+            elif hasattr(response, "SerializeToString"):
+                response = response.SerializeToString()
+
+            self.logger.info(f"TRANSACTION TIMESTAMP: {timestamp}")
+            filename = self.log_transaction_path(str(file_name))
+            self.logger.debug(f"write to {filename}")
+            with open(filename, mode='a', newline='') as file:
+                self.logger.debug("write add_data to transaction log")
+                writer = csv.writer(file)
+                writer.writerow([
+                    timestamp,
+                    sys.argv[0],
+                    endpoint,
+                    agent,
+                    request,
+                    response,
+                    error
+                ])
+                file.close()
+
+    def payload_and_headers(self, payload: Any) -> Tuple[Union[str, bytes], Dict]:
+
+        headers = {}
+        if isinstance(payload, BaseModel):
+            self.logger.debug("payload is pydantic")
+            payload = payload.model_dump_json()
+        elif hasattr(payload, "SerializeToString"):
+            self.logger.debug("payload is protobuf")
+            headers = {'Content-Type': 'application/octet-stream'}
+            payload = dag_crypto.encrypt_aes(payload.SerializeToString(), self.transmission_key)
+        else:
+            raise Exception("Cannot determine if the model is pydantic or protobuf.")
+
+        return payload, headers
+
+    def sync(self,
+             sync_query: Union[SyncQuery, GraphSync_pb2.GraphSyncQuery],
+             graph_id: Optional[int] = None,
+             endpoint: Optional[str] = None,
+             agent: Optional[str] = None) -> bytes:
+
+        # BUGFIX(review): the default agent string was built but never assigned, and
+        # `dag_version` is already the version string (imported `as dag_version`), so
+        # `dag_version.__version__` raised AttributeError whenever agent was None.
+        if agent is None:
+            agent = f"keeper-dag/{dag_version}"
+
+        endpoint = self._endpoint(ConnectionBase.SYNC, endpoint)
+        self.logger.debug(f"endpoint {endpoint}")
+
+        try:
+            sync_query, headers = self.payload_and_headers(sync_query)
+            payload = self.rest_call_to_router(http_method="POST",
+                                               endpoint=endpoint,
+                                               agent=agent,
+                                               headers=headers,
+                                               payload=sync_query)
+
+            if self.use_read_protobuf:
+                try:
+                    self.logger.debug(f"decrypt payload with transmission key {dag_utils.kotlin_bytes(self.transmission_key)}")
+                    payload = self.get_encrypted_payload_data(payload)
+                    payload = dag_crypto.decrypt_aes(payload, self.transmission_key)
+                except Exception as err:
+                    self.logger.error(f"Could not decrypt protobuf graph sync response: {type(err)}, {err}")
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=payload,
+                agent=agent,
+                endpoint=endpoint,
+                error=None
+            )
+
+            return payload
+
+        except exceptions.DAGConnectionException as err:
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise err
+        except Exception as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise exceptions.DAGException(f"Could not load the DAG structure: {err}")
+
+    def debug_dump(self) -> str:
+        return "Connection does not allow debug dump."
+
+    def add_data(self,
+                 # NOTE(review): DataPayload is not imported in this module (only SyncQuery is);
+                 # a string annotation avoids a NameError at class-definition time — the proper
+                 # fix is to import DataPayload from .dag_types.
+                 payload: Union["DataPayload", GraphSync_pb2.GraphSyncAddDataRequest],
+                 graph_id: Optional[int] = None,
+                 endpoint: Optional[str] = None,
+                 use_protobuf: bool = False,
+                 agent: Optional[str] = None):
+
+        # BUGFIX(review): assign the default agent string (was a discarded expression),
+        # and use `dag_version` directly — it is already the version string.
+        if agent is None:
+            agent = f"keeper-dag/{dag_version}"
+
+        endpoint = self._endpoint(ConnectionBase.ADD_DATA, endpoint)
+        self.logger.debug(f"endpoint {endpoint}")
+
+        try:
+            payload, headers = self.payload_and_headers(payload)
+            self.rest_call_to_router(http_method="POST",
+                                     endpoint=endpoint,
+                                     payload=payload,
+                                     headers=headers,
+                                     agent=agent)
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=None
+            )
+        except exceptions.DAGConnectionException as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise err
+        except Exception as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise exceptions.DAGException(f"Could not create a new DAG structure: {err}")
diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/__version__.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/__version__.py
new file mode 100644
index 00000000..12ce4098
--- /dev/null
+++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/__version__.py
@@ -0,0 +1 @@
+__version__ = '1.1.0' # pragma: no cover
diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection.py
new file mode 100644
index 00000000..c6bdf812
--- /dev/null
+++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection.py
@@ -0,0 +1,205 @@
+import logging
+import os
+
+import time
+from typing import Any, Dict, Optional, Tuple, Union
+
+import requests
+from .
import ConnectionBase, dag_utils, exceptions
+
+from keepersdk.vault import vault_online, vault_record
+from keepersdk import crypto, utils
+from keepersdk.authentication import endpoint
+
+class Connection(ConnectionBase):
+
+    def __init__(self,
+                 vault: vault_online.VaultOnline,
+                 verify_ssl: bool = True,
+                 is_ws: bool = False,
+                 logger: Optional[logging.Logger] = None,
+                 log_transactions: Optional[bool] = False,
+                 log_transactions_dir: Optional[str] = None,
+                 use_read_protobuf: bool = False,
+                 use_write_protobuf: bool = False,
+                 **kwargs):
+
+        # Commander uses /api/user; hence is_device=False
+        super().__init__(is_device=False,
+                         logger=logger,
+                         log_transactions=log_transactions,
+                         log_transactions_dir=log_transactions_dir,
+                         use_read_protobuf=use_read_protobuf,
+                         use_write_protobuf=use_write_protobuf)
+
+        self.vault = vault
+        self.verify_ssl = dag_utils.value_to_boolean(os.environ.get("VERIFY_SSL", verify_ssl))
+        self.is_ws = is_ws
+
+        # Deprecated; setting this will override the per-transaction values.
+        self.transmission_key = kwargs.get("transmission_key")
+        self.dep_encrypted_transmission_key = kwargs.get("encrypted_transmission_key")
+        self.dep_encrypted_session_token = kwargs.get("encrypted_session_token")
+
+    @staticmethod
+    def get_record_uid(record: vault_record.KeeperRecord) -> str:
+        return record.record_uid
+
+    @staticmethod
+    def get_key_bytes(record: vault_record.KeeperRecord) -> bytes:
+        return record.record_key
+
+    @property
+    def hostname(self) -> str:
+        # The host is connect.keepersecurity.com, connect.dev.keepersecurity.com, etc. Append "connect" in front
+        # of host used for Commander.
+        # BUGFIX(review): this class stores `self.vault`, not `self.params` (never set here),
+        # so the previous `self.params.config.get("server")` raised AttributeError; read the
+        # server name the same way commander.py does.
+        configured_host = f'connect.{self.vault.keeper_auth.keeper_endpoint.server}'
+
+        # In GovCloud environments, the router service is not under the govcloud subdomain
+        if 'govcloud.' in configured_host:
+            # "connect.govcloud.keepersecurity.com" -> "connect.keepersecurity.com"
+            configured_host = configured_host.replace('govcloud.', '')
+
+        return os.environ.get("ROUTER_HOST", configured_host)
+
+    @property
+    def dag_server_url(self) -> str:
+
+        # Allow override of the URL. If not set, get the hostname from the config.
+        hostname = os.environ.get("KROUTER_URL", self.hostname)
+        if hostname.startswith('ws') or hostname.startswith('http'):
+            return hostname
+
+        use_ssl = dag_utils.value_to_boolean(os.environ.get("USE_SSL", True))
+        if self.is_ws:
+            prot_pref = 'ws'
+        else:
+            prot_pref = 'http'
+        if use_ssl is True:
+            prot_pref += "s"
+
+        return f'{prot_pref}://{hostname}'
+
+    # deprecated
+    def get_keeper_tokens(self):
+        self.transmission_key = utils.generate_aes_key()
+        server_public_key = endpoint.SERVER_PUBLIC_KEYS[self.vault.keeper_auth.keeper_endpoint.server_key_id]
+
+        if self.vault.keeper_auth.keeper_endpoint.server_key_id < 7:
+            self.dep_encrypted_transmission_key = crypto.encrypt_rsa(self.transmission_key, server_public_key)
+        else:
+            self.dep_encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, server_public_key)
+        self.dep_encrypted_session_token = crypto.encrypt_aes_v2(
+            utils.base64_url_decode(self.vault.keeper_auth.auth_context.session_token), self.transmission_key)
+
+    def payload_and_headers(self, payload: Any) -> Tuple[Union[str, bytes], Dict]:
+
+        # If the dep_encrypted_transmission_key, use the set value over the generated ones.
+        if self.dep_encrypted_transmission_key is not None:
+            encrypted_transmission_key = self.dep_encrypted_transmission_key
+            encrypted_session_token = self.dep_encrypted_session_token
+
+        # This is what we want to use; it's different for each call.
+ else: + # Create a new transmission key + self.transmission_key = utils.generate_aes_key() + self.logger.debug(f"transmission key is {self.transmission_key}") + # self.params.rest_context.transmission_key = self.transmission_key + server_public_key = endpoint.SERVER_PUBLIC_KEYS[self.vault.keeper_auth.keeper_endpoint.server_key_id] + + if self.vault.keeper_auth.keeper_endpoint.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(self.transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, server_public_key) + encrypted_session_token = crypto.encrypt_aes_v2( + utils.base64_url_decode(self.vault.keeper_auth.auth_context.session_token), self.transmission_key) + + # We need the transmission_key for protobuf sync since it returns values encrypted with the transmission_key. + if self.transmission_key is None: + raise exceptions.DAGConnectionException("The transmission key has not been set. If setting encrypted_transmission_key " + "and encrypted_session_token, also set transmission_key to 32 bytes. 
" + "Setting the encrypted_transmission_key and encrypted_session_token is " + "deprecated.") + + payload, headers = super().payload_and_headers(payload) + + headers["TransmissionKey"] = utils.base64_url_encode(encrypted_transmission_key) + headers["Authorization"] = f'KeeperUser {utils.base64_url_encode(encrypted_session_token)}' + + return payload, headers + + def rest_call_to_router(self, + http_method: str, + endpoint: str, + agent: str, + payload: Optional[Union[str, bytes]] = None, + retry: int = 5, + retry_wait: float = 10, + throttle_inc_factor: float = 1.5, + timeout: Optional[int] = None, + headers: Optional[Dict] = None) -> Optional[bytes]: + + if timeout is None or timeout == 0: + timeout = Connection.TIMEOUT + + if isinstance(payload, str): + payload = payload.encode() + + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + url = self.dag_server_url + endpoint + + if headers is None: + headers = {} + + attempt = 0 + while True: + try: + attempt += 1 + self.logger.debug(f"graph web service call to {url} [{attempt}/{retry}]") + response = requests.request( + method=http_method, + url=url, + verify=self.verify_ssl, + headers={ + **headers, + 'User-Agent': agent + }, + data=payload, + timeout=timeout + ) + self.logger.debug(f"response status: {response.status_code}") + response.raise_for_status() + return response.content + + except requests.exceptions.HTTPError as http_err: + + msg = http_err.response.reason + try: + content = http_err.response.content.decode() + if content is not None and content != "": + msg = "; " + content + except (Exception,): + pass + + err_msg = f"{http_err.response.status_code}, {msg}" + + if http_err.response.status_code == 429: + attempt -= 1 + retry_wait *= throttle_inc_factor + self.logger.warning("the connection to the graph service is being throttled, " + f"increasing the delay between retry: {retry_wait} seconds.") + + except Exception as err: + err_msg = str(err) + + self.logger.info(f"call to graph web 
service {url} had a problem: {err_msg}") + if attempt >= retry: + self.logger.error(f"call to graph web service {url}, after {retry} " + f"attempts, failed!: {err_msg}") + raise exceptions.DAGConnectionException(f"Call to graph web service {url}, after {retry} " + f"attempts, failed!: {err_msg}") + + self.logger.info(f"will retry call after {retry_wait} seconds.") + time.sleep(retry_wait) diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/__init__.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/__init__.py new file mode 100644 index 00000000..6eabb280 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/__init__.py @@ -0,0 +1,332 @@ +from __future__ import annotations +import logging +from ..__version__ import __version__ +from ....proto import GraphSync_pb2 as gs_pb2 +from ..exceptions import DAGException, DAGConnectionException +from ..dag_types import SyncQuery, DataPayload +from ..dag_utils import value_to_boolean, kotlin_bytes +from ..dag_crypto import encrypt_aes, decrypt_aes +import csv +import os +import time +import sys +from enum import Enum +from pydantic import BaseModel +from typing import Optional, Union, Any, Dict, Tuple, TYPE_CHECKING +if TYPE_CHECKING: # pragma: no cover + Logger = Union[logging.RootLogger, logging.Logger] + +# What is this? +# If used with Commander, router_abbr_pb2 will interfere with router_pb2. +# `TypeError: Couldn't build proto file into descriptor pool: duplicate symbol 'Router.RouterResponseCode'` +# Try to import the Commander first, then fallback to router_abbr_pb2. 
+try: + # noinspection PyUnresolvedReferences + from ....proto import router_pb2 as router_pb2 # type: ignore[import] +except (Exception,): + from ....proto import router_abbr_pb2 as router_pb2 + + +class ConnectionBase: + + ADD_DATA = "/add_data" + SYNC = "/sync" + + TIMEOUT = 30 + + def __init__(self, + is_device: bool, + logger: Optional[Logger] = None, + log_transactions: Optional[bool] = None, + log_transactions_dir: Optional[str] = None, + use_read_protobuf: bool = False, + use_write_protobuf: bool = False): + + # device is a gateway device if is_device is False then we use user authentication flow + self.is_device = is_device + + if logger is None: + logger = logging.getLogger() + self.logger = logger + + # Debug tool; log transaction to the file + if log_transactions is not None: + self.log_transactions = value_to_boolean(log_transactions) + else: + self.log_transactions: bool = value_to_boolean(os.environ.get("GS_LOG_TRANS", False)) + + self.log_transactions_dir = os.environ.get("GS_LOG_TRANS_DIR", log_transactions_dir) + if self.log_transactions_dir is None: + self.log_transactions_dir = "." 
+ + if self.log_transactions is True: + self.logger.info("keeper-dag transaction logging is ENABLED; " + f"write directory at {self.log_transactions_dir}") + + self.use_read_protobuf = use_read_protobuf + self.use_write_protobuf = use_write_protobuf + + # This should stay none for KSM + self.transmission_key = None + + def close(self): + if hasattr(self, "logger"): + self.logger = None + del self.logger + + def __del__(self): + self.close() + + def log_transaction_path(self, file: str): + return os.path.join(self.log_transactions_dir, f"graph_{file}.csv") + + @staticmethod + def get_record_uid(record: object) -> str: + pass + + @staticmethod + def get_key_bytes(record: object) -> bytes: + pass + + @staticmethod + def get_encrypted_payload_data(encrypted_payload_data: bytes) -> bytes: + try: + router_response = router_pb2.RouterResponse() + router_response.ParseFromString(encrypted_payload_data) + return router_response.encryptedPayload + except Exception as err: + raise Exception(f"Could not parse router response: {err}") + + @staticmethod + def get_router_host(server_hostname: str): + + # Only PROD GovCloud strips the subdomain (workaround for prod infrastructure). + # DEV/QA GOV (govcloud.dev.keepersecurity.us, govcloud.qa.keepersecurity.us) keep govcloud. + if server_hostname == 'govcloud.keepersecurity.us': + configured_host = 'connect.keepersecurity.us' + else: + configured_host = f'connect.{server_hostname}' + + return os.environ.get("ROUTER_HOST", configured_host) + + def rest_call_to_router(self, + http_method: str, + endpoint: str, + agent: str, + payload: Optional[Union[str, bytes]] = None, + retry: int = 5, + retry_wait: float = 10, + throttle_inc_factor: float = 1.5, + timeout: Optional[int] = None, + headers: Optional[Dict] = None) -> Optional[bytes]: + return b"" + + def _endpoint(self, action: str, endpoint: Optional[str] = None) -> str: + + """ + Build the endpoint on the remote site. + + This method will attempt to fix slashes. 
+
+        :param action:
+        :param endpoint:
+        :return:
+        """
+
+        # Make sure endpoint is /path/to/endpoint; starting / and no ending /
+        if endpoint is not None and endpoint != "":
+            if isinstance(endpoint, Enum):
+                endpoint = endpoint.value
+
+            while endpoint.startswith("/"):
+                endpoint = endpoint[1:]
+            while endpoint.endswith("/"):
+                endpoint = endpoint[:-1]
+            endpoint = "/" + endpoint
+        else:
+            endpoint = ""
+
+        while action.startswith("/"):
+            action = action[1:]
+        while action.endswith("/"):
+            action = action[:-1]
+        action = "/" + action
+
+        base = "/api/device"
+        if not self.is_device:
+            base = "/api/user"
+
+        return base + endpoint + action
+
+    def write_transaction_log(self,
+                              agent: str,
+                              endpoint: str,
+                              graph_id: Optional[int] = None,
+                              request: Optional[Any] = None,
+                              response: Optional[Any] = None,
+                              error: Optional[str] = None):
+        # If log transaction is True, we want to append to the log file.
+
+        if self.log_transactions is True:
+
+            file_name = graph_id
+            if file_name is None:
+                file_name = endpoint.replace("/", "_")
+
+            timestamp = time.time()
+
+            if isinstance(request, BaseModel):
+                request = request.model_dump_json()
+            elif hasattr(request, "SerializeToString"):
+                request = request.SerializeToString()
+
+            # BUGFIX(review): serialize the response object itself (previously re-serialized
+            # `request`, so the request payload was logged in the response column).
+            if isinstance(response, BaseModel):
+                response = response.model_dump_json()
+            elif hasattr(response, "SerializeToString"):
+                response = response.SerializeToString()
+
+            self.logger.info(f"TRANSACTION TIMESTAMP: {timestamp}")
+            filename = self.log_transaction_path(str(file_name))
+            self.logger.debug(f"write to {filename}")
+            with open(filename, mode='a', newline='') as file:
+                self.logger.debug("write add_data to transaction log")
+                writer = csv.writer(file)
+                writer.writerow([
+                    timestamp,
+                    sys.argv[0],
+                    endpoint,
+                    agent,
+                    request,
+                    response,
+                    error
+                ])
+                file.close()
+
+    def payload_and_headers(self, payload: Any) -> Tuple[Union[str, bytes], Dict]:
+
+        headers = {}
+        if isinstance(payload, BaseModel):
+            self.logger.debug("payload is pydantic")
+            payload = payload.model_dump_json()
+        elif hasattr(payload, "SerializeToString"):
+            self.logger.debug("payload is protobuf")
+            headers = {'Content-Type': 'application/octet-stream'}
+            payload = encrypt_aes(payload.SerializeToString(), self.transmission_key)
+        else:
+            raise Exception("Cannot determine if the model is pydantic or protobuf.")
+
+        return payload, headers
+
+    def sync(self,
+             sync_query: Union[SyncQuery, gs_pb2.GraphSyncQuery],
+             graph_id: Optional[int] = None,
+             endpoint: Optional[str] = None,
+             agent: Optional[str] = None) -> bytes:
+
+        # BUGFIX(review): the default agent string was built but never assigned,
+        # leaving agent=None on the request.
+        if agent is None:
+            agent = f"keeper-dag/{__version__}"
+
+        endpoint = self._endpoint(ConnectionBase.SYNC, endpoint)
+        self.logger.debug(f"endpoint {endpoint}")
+
+        try:
+            sync_query, headers = self.payload_and_headers(sync_query)
+            payload = self.rest_call_to_router(http_method="POST",
+                                               endpoint=endpoint,
+                                               agent=agent,
+                                               headers=headers,
+                                               payload=sync_query)
+
+            if self.use_read_protobuf:
+                try:
+                    self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}")
+                    payload = self.get_encrypted_payload_data(payload)
+                    payload = decrypt_aes(payload, self.transmission_key)
+                except Exception as err:
+                    self.logger.error(f"Could not decrypt protobuf graph sync response: {type(err)}, {err}")
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=payload,
+                agent=agent,
+                endpoint=endpoint,
+                error=None
+            )
+
+            return payload
+
+        except DAGConnectionException as err:
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise err
+        except Exception as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=sync_query,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise DAGException(f"Could not load the DAG structure: {err}")
+
+    def debug_dump(self) -> str:
+        return "Connection does not allow debug dump."
+
+    def add_data(self,
+                 payload: Union[DataPayload, gs_pb2.GraphSyncAddDataRequest],
+                 graph_id: Optional[int] = None,
+                 endpoint: Optional[str] = None,
+                 use_protobuf: bool = False,
+                 agent: Optional[str] = None):
+
+        # BUGFIX(review): assign the default agent string (was a discarded expression).
+        if agent is None:
+            agent = f"keeper-dag/{__version__}"
+
+        endpoint = self._endpoint(ConnectionBase.ADD_DATA, endpoint)
+        self.logger.debug(f"endpoint {endpoint}")
+
+        try:
+            payload, headers = self.payload_and_headers(payload)
+            self.rest_call_to_router(http_method="POST",
+                                     endpoint=endpoint,
+                                     payload=payload,
+                                     headers=headers,
+                                     agent=agent)
+
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=None
+            )
+        except DAGConnectionException as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise err
+        except Exception as err:
+            self.write_transaction_log(
+                graph_id=graph_id,
+                request=payload,
+                response=None,
+                agent=agent,
+                endpoint=endpoint,
+                error=str(err)
+            )
+            raise DAGException(f"Could not create a new DAG structure: {err}")
diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/commander.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/commander.py
new file mode 100644
index 00000000..e238ba97
--- /dev/null
+++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/commander.py
@@ -0,0 +1,208 @@
+from __future__ import annotations
+import logging
+from . import ConnectionBase
+from ..exceptions import DAGConnectionException
+from ....authentication import endpoint
+from ..dag_utils import value_to_boolean
+import os
+import requests
+import time
+
+try: # pragma: no cover
+    from ....
import crypto, utils +except ImportError: # pragma: no cover + raise Exception("Please install the keepercommander module to use the Commander connection.") + +from typing import Optional, Union, Dict, Tuple, Any, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from ....vault import vault_online + from ....vault.vault_record import KeeperRecord + Content = Union[str, bytes, dict] + QueryValue = Union[list, dict, str, float, int, bool] + Logger = Union[logging.RootLogger, logging.Logger] + + +class Connection(ConnectionBase): + + def __init__(self, + vault: vault_online.VaultOnline, + verify_ssl: bool = True, + is_ws: bool = False, + logger: Optional[Logger] = None, + log_transactions: Optional[bool] = False, + log_transactions_dir: Optional[str] = None, + use_read_protobuf: bool = False, + use_write_protobuf: bool = False, + **kwargs): + + # Commander uses /api/user; hence is_device=False + super().__init__(is_device=False, + logger=logger, + log_transactions=log_transactions, + log_transactions_dir=log_transactions_dir, + use_read_protobuf=use_read_protobuf, + use_write_protobuf=use_write_protobuf) + + self.vault = vault + self.verify_ssl = value_to_boolean(os.environ.get("VERIFY_SSL", verify_ssl)) + self.is_ws = is_ws + + # Deprecated; setting this will override the per-transaction values. + self.transmission_key = kwargs.get("transmission_key") + self.dep_encrypted_transmission_key = kwargs.get("encrypted_transmission_key") + self.dep_encrypted_session_token = kwargs.get("encrypted_session_token") + + @staticmethod + def get_record_uid(record: KeeperRecord) -> str: + return record.record_uid + + @staticmethod + def get_key_bytes(record: KeeperRecord) -> bytes: + return record.get_key_bytes() + + @property + def hostname(self) -> str: + return self.get_router_host(self.vault.keeper_auth.keeper_endpoint.server) + + @property + def dag_server_url(self) -> str: + + # Allow override of the URL. If not set, get the hostname from the config. 
+ hostname = os.environ.get("KROUTER_URL", self.hostname) + if hostname.startswith('ws') or hostname.startswith('http'): + return hostname + + use_ssl = value_to_boolean(os.environ.get("USE_SSL", True)) + if self.is_ws: + prot_pref = 'ws' + else: + prot_pref = 'http' + if use_ssl is True: + prot_pref += "s" + + return f'{prot_pref}://{hostname}' + + # deprecated + def get_keeper_tokens(self): + self.transmission_key = utils.generate_aes_key() + server_public_key = endpoint.SERVER_PUBLIC_KEYS[self.vault.keeper_auth.keeper_endpoint.server_key_id] + + if self.vault.keeper_auth.keeper_endpoint.server_key_id < 7: + self.dep_encrypted_transmission_key = crypto.encrypt_rsa(self.transmission_key, server_public_key) + else: + self.dep_encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, server_public_key) + self.dep_encrypted_session_token = crypto.encrypt_aes_v2( + utils.base64_url_decode(self.vault.keeper_auth.auth_context.session_token), self.transmission_key) + + def payload_and_headers(self, payload: Any) -> Tuple[Union[str, bytes], Dict]: + + # If the dep_encrypted_transmission_key, use the set value over the generated ones. + if self.dep_encrypted_transmission_key is not None: + encrypted_transmission_key = self.dep_encrypted_transmission_key + encrypted_session_token = self.dep_encrypted_session_token + + # This is what we want to use; it's different for each call. 
+ else: + # Create a new transmission key + self.transmission_key = utils.generate_aes_key() + self.logger.debug(f"transmission key is {self.transmission_key}") + # self.params.rest_context.transmission_key = self.transmission_key + server_public_key = endpoint.SERVER_PUBLIC_KEYS[self.vault.keeper_auth.keeper_endpoint.server_key_id] + + if self.vault.keeper_auth.keeper_endpoint.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(self.transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, server_public_key) + encrypted_session_token = crypto.encrypt_aes_v2( + utils.base64_url_decode(self.vault.keeper_auth.auth_context.session_token), self.transmission_key) + + # We need the transmission_key for protobuf sync since it returns values encrypted with the transmission_key. + if self.transmission_key is None: + raise DAGConnectionException("The transmission key has not been set. If setting encrypted_transmission_key " + "and encrypted_session_token, also set transmission_key to 32 bytes. 
" + "Setting the encrypted_transmission_key and encrypted_session_token is " + "deprecated.") + + payload, headers = super().payload_and_headers(payload) + + headers["TransmissionKey"] = utils.base64_url_encode(encrypted_transmission_key) + headers["Authorization"] = f'KeeperUser {utils.base64_url_encode(encrypted_session_token)}' + + return payload, headers + + def rest_call_to_router(self, + http_method: str, + endpoint: str, + agent: str, + payload: Optional[Union[str, bytes]] = None, + retry: int = 5, + retry_wait: float = 10, + throttle_inc_factor: float = 1.5, + timeout: Optional[int] = None, + headers: Optional[Dict] = None) -> Optional[bytes]: + + if timeout is None or timeout == 0: + timeout = Connection.TIMEOUT + + if isinstance(payload, str): + payload = payload.encode() + + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + url = self.dag_server_url + endpoint + + if headers is None: + headers = {} + + attempt = 0 + while True: + try: + attempt += 1 + self.logger.debug(f"graph web service call to {url} [{attempt}/{retry}]") + response = requests.request( + method=http_method, + url=url, + verify=self.verify_ssl, + headers={ + **headers, + 'User-Agent': agent + }, + data=payload, + timeout=timeout + ) + self.logger.debug(f"response status: {response.status_code}") + response.raise_for_status() + return response.content + + except requests.exceptions.HTTPError as http_err: + + msg = http_err.response.reason + try: + content = http_err.response.content.decode() + if content is not None and content != "": + msg = "; " + content + except (Exception,): + pass + + err_msg = f"{http_err.response.status_code}, {msg}" + + if http_err.response.status_code == 429: + attempt -= 1 + retry_wait *= throttle_inc_factor + self.logger.warning("the connection to the graph service is being throttled, " + f"increasing the delay between retry: {retry_wait} seconds.") + + except Exception as err: + err_msg = str(err) + + self.logger.info(f"call to graph web 
service {url} had a problem: {err_msg}") + if attempt >= retry: + self.logger.error(f"call to graph web service {url}, after {retry} " + f"attempts, failed!: {err_msg}") + raise DAGConnectionException(f"Call to graph web service {url}, after {retry} " + f"attempts, failed!: {err_msg}") + + self.logger.info(f"will retry call after {retry_wait} seconds.") + time.sleep(retry_wait) diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/ksm.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/ksm.py new file mode 100644 index 00000000..d604a281 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/ksm.py @@ -0,0 +1,331 @@ +# from __future__ import annotations +# from . import ConnectionBase +# from ..dag_utils import value_to_boolean +# from ..exceptions import DAGException, DAGConnectionException + +# from cryptography.hazmat.primitives import hashes +# from cryptography.hazmat.primitives.asymmetric import ec +# from cryptography.hazmat.primitives.serialization import load_der_private_key + +# # try: # pragma: no cover +# # from keeper_secrets_manager_core import utils +# # from keeper_secrets_manager_core.configkeys import ConfigKeys +# # from keeper_secrets_manager_core.storage import InMemoryKeyValueStorage, KeyValueStorage +# # from keeper_secrets_manager_core.utils import url_safe_str_to_bytes, bytes_to_base64, generate_random_bytes +# # except ImportError: # pragma: no cover +# # raise Exception("Please install the keeper_secrets_manager_core module to use the Ksm connection.") + +# import logging +# import json +# import os +# import requests +# import time +# from typing import Union, Optional, Tuple, Dict, Any, TYPE_CHECKING + +# if TYPE_CHECKING: # pragma: no cover +# # from keeper_secrets_manager_core.storage import KeyValueStorage +# # from keeper_secrets_manager_core.dto.dtos import Record +# KsmConfig = Union[dict, str, KeyValueStorage] +# Content = Union[str, bytes, dict] +# QueryValue = 
Union[list, dict, str, float, int, bool] +# Logger = Union[logging.RootLogger, logging.Logger] + + +# class Connection(ConnectionBase): + +# KEEPER_CLIENT = 'ms16.5.0' + +# def __init__(self, +# config: Union[str, dict, KeyValueStorage], +# verify_ssl: bool = None, +# logger: Optional[Logger] = None, +# log_transactions: Optional[bool] = None, +# log_transactions_dir: Optional[str] = None, +# use_read_protobuf: bool = False, +# use_write_protobuf: bool = False): + +# # KSM uses /api/device; hence is_device=True +# super().__init__(is_device=True, +# logger=logger, +# log_transactions=log_transactions, +# log_transactions_dir=log_transactions_dir, +# use_read_protobuf=use_read_protobuf, +# use_write_protobuf=use_write_protobuf) + +# if self.use_read_protobuf: +# self.logger.info("KSM cannot use protobuf for reading the graph, using JSON.") +# self.use_read_protobuf = False +# if self.use_write_protobuf: +# self.logger.info("KSM cannot use protobuf for writing to the graph, using JSON.") +# self.use_read_protobuf = False + +# if InMemoryKeyValueStorage.is_base64(config): +# config = utils.base64_to_string(config) +# if isinstance(config, str): +# try: +# config = json.loads(config) +# except json.JSONDecodeError as err: +# raise DAGException(f"The configuration JSON could not be decoded: {err}") + +# if isinstance(config, dict) is False and isinstance(config, KeyValueStorage) is False: +# raise DAGException("The configuration is not a dictionary.") + +# if verify_ssl is None: +# verify_ssl = value_to_boolean(os.environ.get("VERIFY_SSL", "TRUE")) + +# self.config = config +# self.verify_ssl = verify_ssl +# self._signature = None +# self._challenge_str = None + +# def close(self): +# super().close() +# if hasattr(self, "config"): +# self.config = None +# del self.config + +# @staticmethod +# def get_record_uid(record: Record) -> str: +# return record.uid + +# @staticmethod +# def get_key_bytes(record: Record) -> bytes: +# return record.record_key_bytes + +# def 
get_config_value(self, key: ConfigKeys) -> str: +# if isinstance(self.config, KeyValueStorage): +# return self.config.get(key) +# else: +# return self.config.get(key.value) + +# @property +# def hostname(self) -> str: +# return os.environ.get("ROUTER_HOST", self.get_config_value(ConfigKeys.KEY_HOSTNAME)) + +# @property +# def client_id(self) -> str: +# return self.get_config_value(ConfigKeys.KEY_CLIENT_ID) + +# @property +# def private_key(self) -> str: +# return self.get_config_value(ConfigKeys.KEY_PRIVATE_KEY) + +# @property +# def app_key(self) -> str: +# return self.get_config_value(ConfigKeys.KEY_APP_KEY) + +# def router_url_from_ksm_config(self) -> str: +# return self.get_router_host(self.hostname) + +# def ws_router_url_from_ksm_config(self, is_ws: bool = False) -> str: + +# router_host = self.router_url_from_ksm_config() + +# kpam_router_ssl_enabled_env = value_to_boolean(os.environ.get("USE_SSL", True)) + +# if is_ws: +# prot_pref = 'ws' +# else: +# prot_pref = 'http' + +# if not kpam_router_ssl_enabled_env: +# return f'{prot_pref}://{router_host}' +# else: +# return f'{prot_pref}s://{router_host}' + +# def http_router_url_from_ksm_config_or_env(self) -> str: + +# router_host_from_env = os.getenv("KROUTER_URL") +# if router_host_from_env: +# router_http_host = router_host_from_env +# else: +# router_http_host = self.ws_router_url_from_ksm_config() + +# return router_http_host.replace('ws', 'http') + +# def authenticate(self, +# agent: str, +# refresh: bool = False, +# retry: int = 5, +# retry_wait: float = 10.0, +# throttle_inc_factor: float = 1.5, +# timeout: Optional[int] = None) -> Tuple[str, str]: + +# if self._signature is None or refresh is True: + +# self.logger.debug(f"signature is blank or needs to be refresh {refresh}") + +# if timeout is None or timeout == 0: +# timeout = Connection.TIMEOUT + +# router_http_host = self.http_router_url_from_ksm_config_or_env() +# url = f'{router_http_host}/api/device/get_challenge' + +# self._signature = None + 
+# attempt = 0 +# while True: +# try: +# attempt += 1 +# response = requests.get(url, +# verify=self.verify_ssl, +# timeout=timeout, +# headers={ +# "User-Agent": agent +# }) +# response.raise_for_status() + +# self._challenge_str = response.text +# if self._challenge_str is None or self._challenge_str == "": +# raise Exception("Challenge text is blank. Cannot authenticate into the DAG web service.") + +# private_key_der_bytes = url_safe_str_to_bytes(self.private_key) +# client_id_bytes = url_safe_str_to_bytes(self.client_id) + +# self.logger.debug('adding challenge to the signature before connecting to the router') +# challenge_bytes = url_safe_str_to_bytes(self._challenge_str) +# client_id_bytes = client_id_bytes + challenge_bytes + +# pk = load_der_private_key(private_key_der_bytes, password=None) +# sig = pk.sign(client_id_bytes, ec.ECDSA(hashes.SHA256())) + +# self._signature = bytes_to_base64(sig) +# break + +# except requests.exceptions.HTTPError as http_err: + +# msg = http_err.response.reason +# try: +# content = http_err.response.content.decode() +# if content is not None and content != "": +# msg = "; " + content +# except (Exception,): +# pass + +# err_msg = f"{http_err.response.status_code}, {msg}" + +# if http_err.response.status_code == 429: +# retry_wait *= throttle_inc_factor +# attempt -= 1 +# self.logger.warning( +# "the connection to the graph service, for authentication, is being throttled; " +# f"increasing delay between retry: {retry_wait} seconds") + +# except Exception as err: +# err_msg = str(err) + +# self.logger.info(f"call to challenge had a problem: {err_msg}.") +# if attempt >= retry: +# raise DAGConnectionException(f"Call to challenge {url}, after {retry} " +# f"attempts, failed!: {err_msg}") + +# self.logger.info(f"will retry call after {retry_wait} seconds.") +# time.sleep(retry_wait) + +# return self._signature, self._challenge_str + +# def payload_and_headers(self, payload: Any) -> Tuple[Union[str, bytes], Dict]: + +# # Make sure 
the transmission_key is None. +# # This acts as a flag to indicate if we need to decrypt the response. +# self.transmission_key = None + +# payload, headers = super().payload_and_headers(payload) + +# return payload, headers + +# def rest_call_to_router(self, +# http_method: str, +# endpoint: str, +# agent: str, +# payload: Optional[Union[str, bytes]] = None, +# retry: int = 5, +# retry_wait: float = 10.0, +# throttle_inc_factor: float = 1.5, +# timeout: Optional[int] = None, +# headers: Optional[Dict] = None) -> Optional[bytes]: + +# if timeout is None or timeout == 0: +# timeout = Connection.TIMEOUT + +# if headers is None: +# headers = {} + +# if isinstance(payload, str): +# payload = payload.encode() + +# router_host = self.http_router_url_from_ksm_config_or_env() +# url = router_host + endpoint + +# refresh = False +# attempt = 0 +# while True: + +# attempt += 1 + +# # Keep authenticate outside the call router try. +# # This is to prevent too many retries. +# # For example, 3 retry of the auth, 3 retry of the request, will be 9 retries. +# signature, challenge_str = self.authenticate(refresh=refresh, agent=agent) +# headers = { +# **headers, +# "Signature": signature, +# "ClientVersion": Connection.KEEPER_CLIENT, +# "Authorization": f'KeeperDevice {self.client_id}', +# "Challenge": challenge_str, +# "User-Agent": agent +# } +# self.logger.debug(f'connecting with headers: {headers}') + +# try: +# self.logger.debug(f"DAG web service call to {url} [{attempt}/{retry}]") +# response = requests.request( +# method=http_method, +# url=url, +# data=payload, +# verify=self.verify_ssl, +# timeout=timeout, +# headers=headers, +# ) + +# self.logger.debug(f"response status: {response.status_code}") + +# # If we get a 401 Unauthorized, and we have not yet refreshed, +# # refresh the signature. +# if response.status_code == 401 and refresh is False: +# response.close() +# self.logger.debug("rest call was Unauthorized") + +# # The attempt didn't count. 
+# # We get one refresh, then it becomes an exception. +# refresh = True +# attempt -= 1 +# continue + +# response.raise_for_status() +# return response.content + +# # Handle errors outside of requests +# except requests.exceptions.HTTPError as http_err: + +# err_msg = f"{http_err.response.status_code}, {http_err.response.reason}, {http_err.response.content}" +# content = http_err.response.reason + +# if http_err.response.status_code == 429: +# retry_wait *= throttle_inc_factor +# attempt -= 1 +# self.logger.warning("the connection to the graph service is being throttled, " +# f"increasing the delay between retry: {retry_wait} seconds.") + +# except Exception as err: +# err_msg = str(err) +# content = None + +# self.logger.info(f"call to graph web service had a problem: {err_msg}, {content}") +# if attempt >= retry: +# self.logger.info(f"payload: {payload}") +# raise DAGConnectionException(f"Call to graph web service {url}, after {retry} " +# f"attempts, failed!: {err_msg}: {content} : {payload}") + +# self.logger.info(f"will retry call after {retry_wait} seconds.") +# time.sleep(retry_wait) diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/local.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/local.py new file mode 100644 index 00000000..879c53b9 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/connection/local.py @@ -0,0 +1,646 @@ +from . import ConnectionBase +from ..struct.protobuf import DataStruct as PbDataStruct +from ....proto import GraphSync_pb2 as gs_pb2 +from ..dag_types import DataPayload, EdgeType, SyncQuery, Ref, RefType, DAGData, SyncDataItem, SyncData +from ..dag_crypto import bytes_to_urlsafe_str, urlsafe_str_to_bytes +from ..dag_utils import value_to_boolean +from .... 
import utils +import json +import os +import logging +from enum import Enum +from tabulate import tabulate + +try: # pragma: no cover + import sqlite3 + from contextlib import closing +except ImportError: + raise Exception("Please install the sqlite3 module to use the Local connection.") + +from typing import Optional, Union, Any, TYPE_CHECKING # pragma: no cover +if TYPE_CHECKING: + Logger = Union[logging.RootLogger, logging.Logger] + + +class Connection(ConnectionBase): + + """ + BIG TIME NOTE + + This is a fake DAG engine used for unit tests. + It tries best to emulate krouter/workflow. + This is no substitute for testing against a krouter instance. + """ + + DB_FILE = "local_dag.db" + + def __init__(self, + limit: int = 100, + db_file: Optional[str] = None, + db_dir: Optional[str] = None, + logger: Optional[Any] = None, + log_transactions: Optional[bool] = None, + log_transactions_dir: Optional[str] = None, + use_read_protobuf: bool = False, + use_write_protobuf: bool = False): + + super().__init__(is_device=False, + logger=logger, + log_transactions=log_transactions, + log_transactions_dir=log_transactions_dir, + use_read_protobuf=use_read_protobuf, + use_write_protobuf=use_write_protobuf) + + if db_file is None: + db_file = os.environ.get("LOCAL_DAG_DB_FILE", Connection.DB_FILE) + if db_dir is None: + db_dir = os.environ.get("LOCAL_DAG_DIR", os.environ.get("HOME", os.environ.get("USERPROFILE", "./"))) + + self.allow_debug = value_to_boolean(os.environ.get("GS_CONN_DEBUG", False)) + if self.allow_debug is True: + self.debug("enabling GraphSync connection logging") + + self.db_file = os.path.join(db_dir, db_file) + self.limit = limit + + self.create_database() + + def debug(self, msg): + if self.allow_debug: + self.logger.debug(f"GraphSync LOCAL: {msg}") + + @staticmethod + def get_record_uid(record: object) -> bytes: + if hasattr(record, "record_uid"): + return getattr(record, "record_uid") + elif hasattr(record, "uid"): + return getattr(record, "uid") + raise 
Exception(f"Cannot find the record uid in object type: {type(record)}.") + + @staticmethod + def get_key_bytes(record: object) -> bytes: + if hasattr(record, "record_key_bytes"): + return getattr(record, "record_key_bytes") + elif hasattr(record, "record_key"): + return getattr(record, "record_key") + raise Exception("Cannot find the record key bytes in object.") + + def clear_database(self): + try: + os.unlink(self.db_file) + except (Exception,): + pass + + def create_database(self): + + self.debug("create local dag database") + + if os.path.isfile(self.db_file): + return False + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + # This is based on workflow, Database.kt. + # The UIDs are stored a character instead of bytes to make them more readable for debugging. + + # The 'type' columns are stored a TEXT. + # This is because the WS wants text for the enum, but stores + # it as an INTEGER. + # We are just going to store it as a TEXT and avoid the middle man. 
+ + cursor.execute( + """ +CREATE TABLE IF NOT EXISTS dag_edges ( + graph_id INTEGER, + edge_id INTEGER PRIMARY KEY AUTOINCREMENT, + type TEXT NOT NULL, + head CHARACTER(22) NOT NULL, + tail CHARACTER(22) NOT NULL, + data BLOB, + origin CHARACTER(22), + path TEXT, + created timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + creator_id BLOB(16) DEFAULT NULL, + creator_type INTEGER DEFAULT NULL, + creator_name TEXT DEFAULT NULL, + FOREIGN KEY(head) REFERENCES dag_vertices(vertex_id), + FOREIGN KEY(tail) REFERENCES dag_vertices(vertex_id) +) + """ + ) + + cursor.execute( + """ +CREATE TABLE IF NOT EXISTS dag_vertices ( + vertex_id CHARACTER(22) NOT NULL, + type TEXT NOT NULL, + name TEXT, + owner_id BLOB(16) DEFAULT NULL +) + """ + ) + + cursor.execute( + """ +CREATE TABLE IF NOT EXISTS dag_streams ( + graph_id INTEGER, + sync_point INTEGER PRIMARY KEY AUTOINCREMENT, + vertex_id CHARACTER(22) NOT NULL, + edge_id INTEGER NOT NULL, + count INTEGER NOT NULL DEFAULT 0, + deletion INTEGER NOT NULL DEFAULT 0, + UNIQUE(vertex_id,edge_id), + FOREIGN KEY(vertex_id) REFERENCES dag_vertices(vertex_id), + FOREIGN KEY(edge_id) REFERENCES dag_edges(edge_id) +) + """ + ) + connection.commit() + + os.chmod(self.db_file, 0o777) + return None + + @staticmethod + def _payload_to_json(payload: Union[DataPayload, str]) -> dict: + + # if payload is DataPayload + payload_data = "{}" + if isinstance(payload, DataPayload): + payload_data = payload.model_dump_json() + elif isinstance(payload, str): + payload_data = payload + + # make sure it is a valid json and raise and exception if not. 
make an exception for the case of a string + # that is a valid json + if not payload_data.startswith('{') and not payload_data.endswith('}'): + raise Exception(f'Invalid payload: {payload_data}') + + # double check if it is a valid json inside the string + json.loads(payload_data) + + return json.loads(payload_data) + + def _find_stream_id(self, payload: DataPayload): + + data = Connection._payload_to_json(payload) + + # Find the vertex that does not belong to any other vertex. + # This is normally root for a full DAG, but will be a vertex if adding additional edges. + # 100% sure this could be written better. + # 1000% sure this could be written better. + # TODO: Only refs that are type PAM_NETWORK or PAM_USER can contain the stream id. + # Change code to ignore all other ref types. + + self.debug("finding stream id") + + # First check if we can route with existing edges in the database. + stream_id = None + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + graph_id = data.get("graphId") + + stream_ids = {} + + runs = 0 + for item in data.get("dataList"): + + # Get the head UID of the edge and then find an edge where the UID is the tail. + # If we find an edge, use its head to find an edge where the UID is the tail. + # Repeat until we can't find and edge, that is a stream ID + # Tally all the stream ID and take the best. + item_stream_id = item.get("ref")["value"] + current_stream_id = item_stream_id + while True: + self.debug(f" check stream id {current_stream_id}") + sql = "SELECT head, edge_id FROM dag_edges WHERE tail=? AND graph_id=? AND type != ?" 
+ res = cursor.execute(sql, (current_stream_id, graph_id, EdgeType.DATA.value)) + row = res.fetchone() + if row is None: + self.debug(f" no edge found") + if current_stream_id == item_stream_id: + current_stream_id = None + break + current_stream_id = row[0] + self.debug(f" got {current_stream_id}") + + if current_stream_id is not None: + if item_stream_id not in stream_ids: + stream_ids[current_stream_id] = 0 + stream_ids[current_stream_id] += 1 + else: + # If we didn't find anything with the tail, check starting with the head. + item_stream_id = item.get("parentRef")["value"] + current_stream_id = item_stream_id + while True: + self.debug(f" check stream id {current_stream_id}") + sql = "SELECT head, edge_id FROM dag_edges WHERE tail=? AND graph_id=? AND type != ?" + res = cursor.execute(sql, (current_stream_id, graph_id, EdgeType.DATA.value)) + row = res.fetchone() + if row is None: + self.debug(f" no edge found") + if current_stream_id == item_stream_id: + current_stream_id = None + break + current_stream_id = row[0] + self.debug(f" got {current_stream_id}") + + if current_stream_id is not None: + if item_stream_id not in stream_ids: + stream_ids[current_stream_id] = 0 + stream_ids[current_stream_id] += 1 + + # Until we rewrite this, exit after we check 3 edges. + # This will slow down after a bunch of edges are added. + # We also fixed stuff in our code to prevent the errors we were seeing. + # Might want to switch to recursion. + # https://www.sqlite.org/lang_with.html + if runs > 3: + break + runs += 1 + + if len(stream_ids) > 0: + sorted_stream_ids = [k for k, v in sorted(stream_ids.items(), key=lambda i: i[1])] + stream_id = sorted_stream_ids.pop() + + # If the stream id is None, this is the first save of the DAG. + # No edges existed. + # Compare the data list items. + # The one without an edge with a tail if the stream id. 
+ if stream_id is None: + self.debug("stream id None, edges might be new") + # Get a starting spot + found = {} + for item in data.get("dataList"): + head_uid = item.get("parentRef")["value"] + found[head_uid] = True + for item in data.get("dataList"): + tail_uid = item.get("ref")["value"] + found.pop(tail_uid, None) + stream_ids = [uid for uid in found] + if len(stream_ids) > 0: + stream_id = stream_ids[0] + + # If we can't find stream ID, assume it's on the first item in the dataList + if stream_id is None: + item = data.get("dataList")[0] + stream_id = item.get("parentRef")["value"] or item.get("ref")["value"] + + return stream_id + + @staticmethod + def _add_data_pb_to_pydantic(payload: gs_pb2.GraphSyncAddDataRequest) -> DataPayload: + + data = [] + for item in payload.data: + data.append( + DAGData( + type=PbDataStruct.PB_TO_DATA_MAP.get(item.type), + content=bytes_to_urlsafe_str(item.content), + ref=Ref( + type=PbDataStruct.PB_TO_REF_MAP.get(item.ref.type), + value=bytes_to_urlsafe_str(item.ref.value), + name=item.ref.name, + ), + parentRef=Ref( + type=PbDataStruct.PB_TO_REF_MAP.get(item.parentRef.type), + value=bytes_to_urlsafe_str(item.parentRef.value), + name=item.parentRef.name + ), + path=item.path + ) + ) + + return DataPayload( + origin=Ref( + type=PbDataStruct.PB_TO_REF_MAP.get(payload.origin.type), + value=bytes_to_urlsafe_str(payload.origin.value), + name=payload.origin.name, + ), + dataList=data + ) + + def add_data(self, + payload: Union[DataPayload, gs_pb2.GraphSyncAddDataRequest], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + use_protobuf: bool = False, + agent: Optional[str] = None): + + # Convert protobuf to the pydantic structure + if isinstance(payload, gs_pb2.GraphSyncAddDataRequest): + payload = self._add_data_pb_to_pydantic(payload) + + stream_id = self._find_stream_id(payload) + self.debug(f"STREAM ID IS {stream_id}") + + endpoint = self._endpoint( + action="/add_data", + endpoint=endpoint) + 
self.logger.debug(f"endpoint, local test = {endpoint}") + + data = Connection._payload_to_json(payload) + + self.write_transaction_log( + graph_id=payload.graphId, + request=json.dumps(data), + response=None, + agent=agent, + endpoint=endpoint, + error=None + ) + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + origin_id = data.get("origin")["value"] + graph_id = data.get("graphId") + + saved_vertex = {} + for item in data.get("dataList"): + + tail_uid = item.get("ref")["value"] + tail_type = item.get("ref")["type"] + tail_name = item.get("ref")["name"] + + head_uid = None + head_type = None + head_name = None + if item.get("parentRef") is not None: + head_uid = item.get("parentRef")["value"] + head_type = item.get("parentRef")["type"] + head_name = item.get("parentRef")["name"] + + edge_type = item.get("type") + path = item.get("path") + + content = item.get("content") + if content is not None: + content = utils.base64_url_decode(content) + + sql = "INSERT INTO dag_edges (type, head, tail, data, origin, graph_id, path) " + sql += "VALUES (?,?,?,?,?,?,?)" + cursor.execute(sql, ( + edge_type, + head_uid, + tail_uid, + content, + origin_id, + graph_id, + path + )) + edge_id = cursor.lastrowid + + sql = "INSERT INTO dag_streams (graph_id, vertex_id, edge_id, count) VALUES (?, ?, ?, ?)" + cursor.execute(sql, ( + graph_id, + stream_id, + edge_id, + 1 + )) + + if saved_vertex.get(tail_uid) is None: + # Type is RefType enum value + sql = "INSERT INTO dag_vertices (vertex_id, type, name) VALUES (?, ?, ?)" + cursor.execute(sql, ( + tail_uid, + tail_type, + tail_name + )) + saved_vertex[tail_uid] = True + if saved_vertex.get(head_uid) is None: + # Type is RefType enum value + sql = "INSERT INTO dag_vertices (vertex_id, type, name) VALUES (?, ?, ?)" + cursor.execute(sql, ( + head_uid, + head_type, + head_name + )) + saved_vertex[head_uid] = True + + connection.commit() + + @staticmethod + def 
_sync_pb_to_pydantic(payload: gs_pb2.GraphSyncQuery) -> SyncQuery: + + return SyncQuery( + streamId=bytes_to_urlsafe_str(payload.streamId), + graphId=payload.syncPoint, + syncPoint=payload.syncPoint + ) + + def sync(self, + sync_query: Union[SyncQuery, gs_pb2.GraphSyncQuery], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + + is_protobuf = False + if isinstance(sync_query, gs_pb2.GraphSyncQuery): + is_protobuf = True + sync_query = self._sync_pb_to_pydantic(sync_query) + + edge_type_map = { + EdgeType.DATA.value: "data", + EdgeType.KEY.value: "key", + EdgeType.LINK.value: "link", + EdgeType.ACL.value: "acl", + EdgeType.DELETION.value: "deletion", + EdgeType.DENIAL.value: "denial", + EdgeType.UNDENIAL.value: "undenial", + } + + stream_id = sync_query.streamId + graph_id = sync_query.graphId + sync_point = sync_query.syncPoint + + endpoint = self._endpoint( + action="/sync", + endpoint=endpoint) + self.logger.debug(f"endpoint, local test = {endpoint}") + + if isinstance(sync_query.graphId, Enum): + graph_id = sync_query.graphId.value + + self.write_transaction_log( + graph_id=graph_id, + request=sync_query, + response=None, + agent=agent, + endpoint=endpoint, + error=None + ) + + has_more = False + new_sync_point = 0 + data = [] + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + self.debug(f"... loading DAG, {stream_id}, {sync_point}, {self.limit + 1}") + + args = [stream_id, sync_point, graph_id] + sql = "SELECT sync_point, edge_id FROM dag_streams WHERE vertex_id = ? AND deletion = 0 "\ + "AND sync_point > ? AND graph_id=? ORDER BY sync_point ASC LIMIT ?" + args.append(self.limit + 1) + + res = cursor.execute(sql, tuple(args)) + rows = list(res.fetchall()) + if len(rows) > self.limit: + has_more = True + rows.pop() + self.logger.debug(f"... 
loaded {len(rows)} edges") + for row in rows: + new_sync_point = row[0] + + args = [row[1], graph_id] + sql = "SELECT head, tail, data, path, type FROM dag_edges WHERE edge_id = ? AND graph_id=?" + res = cursor.execute(sql, tuple(args)) + edges = res.fetchone() + + # If the head and tail are the same (DATA edge), then parent_ref is None. + # Else include a parent_ref + parent_ref = None + if edges[1] != edges[0]: + + sql = "SELECT type FROM dag_vertices WHERE vertex_id = ?" + res = cursor.execute(sql, (edges[0],)) + head_vertex = res.fetchone() + + parent_ref = { + "type": head_vertex[0], + "value": edges[0], + "name": None + } + + sql = "SELECT type FROM dag_vertices WHERE vertex_id = ?" + res = cursor.execute(sql, (edges[1],)) + tail_vertex = res.fetchone() + + if is_protobuf: + data.append( + gs_pb2.GraphSyncDataPlus( + data=gs_pb2.GraphSyncData( + type=PbDataStruct.DATA_TO_PB_MAP.get(EdgeType.find_enum(edges[4])), + content=edges[2], + path=edges[3], + ref=gs_pb2.GraphSyncRef( + type=PbDataStruct.REF_TO_PB_MAP.get(EdgeType.find_enum(tail_vertex[0])), + value=urlsafe_str_to_bytes(edges[1]) + ), + parentRef=gs_pb2.GraphSyncRef( + type=PbDataStruct.REF_TO_PB_MAP.get(EdgeType.find_enum(parent_ref.get("type"))), + value=urlsafe_str_to_bytes(parent_ref.get("value")) + ) if parent_ref else None, + ) + ) + ) + else: + content = edges[2] + if content is not None: + content = utils.base64_url_decode(content) + + data.append( + SyncDataItem( + type=EdgeType.find_enum(edges[4]), + content=content, + path=edges[3], + deletion=False, + ref=Ref( + type=RefType.find_enum(tail_vertex[0]), + value=edges[1] + ), + parentRef=Ref( + type=RefType.find_enum(parent_ref.get("type")), + value=parent_ref.get("value") + ) if parent_ref else None, + ) + ) + + if is_protobuf: + return gs_pb2.GraphSyncResult( + streamId=urlsafe_str_to_bytes(stream_id), + syncPoint=new_sync_point, + data=data, + hasMore=has_more + ).SerializeToString() + else: + return SyncData( + syncPoint=new_sync_point, + 
data=data, + hasMore=has_more + ).model_dump_json().encode() + + def debug_dump(self) -> str: + + ret = "" + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + cols = ["graph_id", "edge_id", "type", "head", "tail", "data", "origin", "path", "created", + "creator_id", "creator_type", "creator_name"] + + sql = f"SELECT {','.join(cols)} FROM dag_edges ORDER BY edge_id DESC" + res = cursor.execute(sql,) + + ret += "dag_edges\n" + ret += "=========\n" + table = [] + for row in res.fetchall(): + table.append(list(row)) + + ret += tabulate(table, cols) + "\n\n" + + cols = ["e.graph_id", "e.edge_id", "v.vertex_id", "v.type", "v.name", "v.owner_id"] + + sql = f"SELECT {','.join(cols)} "\ + "FROM dag_vertices v "\ + "INNER JOIN dag_edges e ON e.tail = v.vertex_id "\ + "ORDER BY e.graph_id DESC, e.edge_id DESC" + res = cursor.execute(sql,) + + ret += "dag_vertices\n" + ret += "============\n" + table = [] + for row in res.fetchall(): + table.append(list(row)) + + ret += tabulate(table, cols) + "\n\n" + + cols = ["graph_id", "edge_id", "sync_point", "vertex_id", "count", "deletion"] + + sql = f"SELECT {','.join(cols)} FROM dag_streams ORDER BY edge_id DESC" + res = cursor.execute(sql,) + + ret += "dag_streams\n" + ret += "===========\n" + table = [] + for row in res.fetchall(): + table.append(list(row)) + + ret += tabulate(table, cols) + "\n\n" + + return ret + + def update_edge_content(self, graph_id: int, head_uid: str, tail_uid: str, content: str): + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + sql = "UPDATE dag_edges SET data=? WHERE graph_id=? AND head=? AND tail=?" 
+ cursor.execute(sql, (content, graph_id, head_uid, tail_uid)) + + connection.commit() + + def clear(self): + + with closing(sqlite3.connect(self.db_file)) as connection: + with closing(connection.cursor()) as cursor: + + for table in ["dag_streams", "dag_edges", "dag_vertices"]: + sql = f"DELETE FROM {table}" + cursor.execute(sql, ) + + connection.commit() diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/constants.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/constants.py new file mode 100644 index 00000000..ea6e1f85 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/constants.py @@ -0,0 +1,61 @@ +# NOTE: The graph_id constant are part of keeper-dag as enums now. + +# This should the relationship between Keeper Vault record +RECORD_LINK_GRAPH_ID = 0 + +# The rules +DIS_RULES_GRAPH_ID = 10 + +# The discovery job history +DIS_JOBS_GRAPH_ID = 11 + +# Discovery infrastructure +DIS_INFRA_GRAPH_ID = 12 + +# The user-to-services graph +USER_SERVICE_GRAPH_ID = 13 + +PAM_DIRECTORY = "pamDirectory" +PAM_DATABASE = "pamDatabase" +PAM_MACHINE = "pamMachine" +PAM_USER = "pamUser" +LOCAL_USER = "local" + +PAM_RESOURCES = [ + PAM_DIRECTORY, + PAM_DATABASE, + PAM_MACHINE +] + +PAM_DOMAIN_CONFIGURATION = "pamDomainConfiguration" +PAM_AZURE_CONFIGURATION = "pamAzureConfiguration" +PAM_AWS_CONFIGURATION = "pamAwsConfiguration" +PAM_NETWORK_CONFIGURATION = "pamNetworkConfiguration" +PAM_GCP_CONFIGURATION = "pamGcpConfiguration" + +PAM_CONFIGURATIONS = [ + PAM_DOMAIN_CONFIGURATION, + PAM_AZURE_CONFIGURATION, + PAM_AWS_CONFIGURATION, + PAM_NETWORK_CONFIGURATION, + PAM_GCP_CONFIGURATION +] + +# These are configuration that could domain users. +# Azure included because of AADDS. +DOMAIN_USER_CONFIGS = [ + PAM_DOMAIN_CONFIGURATION, + PAM_AZURE_CONFIGURATION +] + +# The record types to process. +# The order defined the order the user will be presented the new discovery objects. 
The sort defines how
+ EDGE_LABEL = { + EdgeType.DATA: "DATA", + EdgeType.KEY: "KEY", + EdgeType.LINK: "LINK", + EdgeType.ACL: "ACL", + EdgeType.DELETION: "DELETION", + } + + def __init__(self, + conn: ConnectionBase, + record: Optional[object] = None, + key_bytes: Optional[bytes] = None, + name: Optional[str] = None, + read_endpoint: Optional[Union[str, Enum]] = None, + write_endpoint: Optional[Union[str, Enum]] = None, + graph_id: Optional[Union[int, Enum]] = None, + auto_save: bool = False, + history_level: int = 0, + logger: Optional[Any] = None, + debug_level: int = 0, + is_dev: bool = False, + vertex_type: RefType = RefType.PAM_NETWORK, + decrypt: bool = True, + fail_on_corrupt: bool = True, + data_requires_encryption: bool = False, + log_prefix: str = "GraphSync", + save_batch_count: Optional[int] = None, + agent: Optional[str] = None, + dedup_edges: bool = False): + + """ + Create a GraphSync instance. + + :param conn: Connection instance + :param record: If set, the key bytes will use the key bytes in the record. Overrides key_bytes. + :param key_bytes: If set, these key bytes will be used. + :param name: Optional name for the graph. + :param read_endpoint: Endpoint for reading from graph. Use this over `graph_id`. (i.e. graph-sync/pam ) + :param write_endpoint: Endpoint for writing to graph. Use this over `graph_id`. (i.e. graph-sync/pam ) + :param graph_id: Graph ID sets which graph to load for the graph. `endpoint` replaces this, but code is + backwards compatiable. + :param auto_save: Automatically save when modifications are performed. Default is False. + :param history_level: How much edge history to keep in memory. Default is 0, no history. + :param logger: Python logger instance to use for logging. + :param debug_level: Debug level; the higher the number will result in more debug information. + :param is_dev: Is the code running in a development environment? + :param vertex_type: The default vertex/ref type for the root vertex, if auto creating. 
+ :param decrypt: Decrypt the graph; Default is TRUE + :param fail_on_corrupt: If unable to decrypt encrypted data, fail out. + :param data_requires_encryption: Data edges are already encrypted. Default is False. + :param log_prefix: Text prepended to the log messages. Handy if dealing with multiple graphs. + :param save_batch_count: The number of edges to save at one time. + :param agent: User Agent to send with web service requests. + :param dedup_edges: Remove modified edges if the same edge added before save. + :return: Instance of GraphSync + """ + + if logger is None: + logger = logging.getLogger() + self.logger = logger + if debug_level is None: + debug_level = int(os.environ.get("GS_DEBUG_LEVEL", os.environ.get("DAG_DEBUG_LEVEL", 0))) + + # Prevent duplicate edges to be added. + # The goal is to prevent unneeded edges. + # If warning is turned on, log dup and stacktrace. + self.dedup_edge = dag_utils.value_to_boolean(os.environ.get("GS_DEDUP_EDGES", dedup_edges)) + self.dedup_edge_warning = dag_utils.value_to_boolean(os.environ.get("GS_DEDUP_EDGES_WARN", False)) + + if self.dedup_edge and auto_save: + raise Exception("Cannot run dedup_edge and auto_save at the same time. The dedup_edge feature only works " + "in bulk saves.") + + self.debug_level = debug_level + self.log_prefix = log_prefix + + if save_batch_count is None or save_batch_count <= 0: + save_batch_count = 0 + self.save_batch_count = save_batch_count + + self.vertex_type = vertex_type + + self.data_requires_encryption = data_requires_encryption + self.decrypt = decrypt + self.fail_on_corrupt = fail_on_corrupt + + gs_is_dev = os.environ.get("GS_IS_DEV", os.environ.get("DAG_IS_DEV")) + if gs_is_dev is not None: + is_dev = dag_utils.value_to_boolean(gs_is_dev) + self.is_dev = is_dev + + # If the record is passed in, use the UID and key bytes from the record. 
+ self.uid = None + if record is not None: + self.uid = conn.get_record_uid(record) + key_bytes = conn.get_key_bytes(record) + + self.key = key_bytes + + if key_bytes is None: + raise ValueError("Either the record or the key_bytes needs to be passed.") + + # If the UID is blank, use the key bytes to generate a UID + if self.uid is None: + self.uid = dag_crypto.generate_uid_str(key_bytes[:16]) + + if graph_id is None and (read_endpoint is None or write_endpoint is None): + raise ValueError("Either graph_id or read/write endpoints needs to be set.") + + # graph_id and endpoint determine how/where the graph is stored on the GraphSync service. + if graph_id is not None: + if isinstance(read_endpoint, Enum): + graph_id = ENDPOINT_TO_GRAPH_ID_MAP.get(read_endpoint.value) + self.graph_id = graph_id + + if read_endpoint is not None: + if isinstance(read_endpoint, Enum): + read_endpoint = read_endpoint.value + if write_endpoint is not None: + if isinstance(write_endpoint, Enum): + write_endpoint = write_endpoint.value + + self.read_endpoint = read_endpoint + self.write_endpoint = write_endpoint + + if name is None: + name = f"{self.log_prefix} ROOT" + self.name = name + + # The order of the vertices is important. + # The order creates the history. + # The web service will order edge by their edge_id + # Store in and array. + # The lookup table to make UID to DAGVertex easier. + # The integer is the index into the array. + self._vertices = [] + self._uid_lookup = {} + + # This is like the batch + self.origin_ref_value = dag_crypto.generate_uid_bytes(16) + self.origin_uid = dag_crypto.generate_uid_str(uid_bytes=self.origin_ref_value) + + # If True, any addition or changes will automatically be saved. + self.auto_save = auto_save + + # To auto save, both allow_auto_save and auto_save needs to be True. + # If the graph has not been saved before and the root vertex has not been connected, + # we want to disable auto save. 
+ self._allow_auto_save = False + + # For big changes, we need a confirmation to save. + self.need_save_confirm = False + + # The last sync point after save. + self.last_sync_point = 0 + + # Amount of history to keep. + # The default is 0, which will keep all history. + # Setting to 1 will only keep the latest edges. + # Settings to 2 will keep the latest and prior edges. + # And so on. + self.history_level = history_level + + # If data was corrupt in the graph, the vertex UID will appear in this list. + self.corrupt_uids = [] + + self.conn = conn + + self.read_struct_obj: Union[ProtobufDataStruct, DefaultDataStruct] = ProtobufDataStruct() \ + if conn.use_read_protobuf else DefaultDataStruct() + self.write_struct_obj: Union[ProtobufDataStruct, DefaultDataStruct] = ProtobufDataStruct() \ + if conn.use_write_protobuf else DefaultDataStruct() + + self.agent = f"keeper-dag/{dag_version.__version__}" + if agent is not None: + self.agent += "; " + agent + + self.debug(f"save batch count is set to {self.save_batch_count}") + if self.is_dev is True: + self.debug("GraphSync is running in a development environment, vertex names will be included.") + self.debug(f"edge de-dup is {self.dedup_edge}", level=1) + self.debug(f"edge de-dup debug warning {self.dedup_edge_warning}", level=1) + self.debug(f"{self.log_prefix} key {self.key}", level=1) + self.debug(f"{self.log_prefix} UID {self.uid}", level=1) + self.debug(f"{self.log_prefix} UID HEX {dag_crypto.urlsafe_str_to_bytes(self.uid).hex()}", level=1) + + def __del__(self): + self.cleanup() + + def cleanup(self): + """ + Explicitly clean up the DAG and break circular references. + + This method allows users to manually trigger cleanup before the object + goes out of scope. 
This is useful in scenarios where you want to ensure + immediate memory release, such as: + - High-frequency DAG creation/destruction + - Long-running processes + - Memory-constrained environments + + After calling this method, the DAG object should not be used. + + Example: + dag = DAG(conn=conn, key_bytes=key) + # ... use the dag ... + dag.cleanup() # Explicitly clean up + del dag + """ + + try: + # Safely get the root vertex without creating a new one + if hasattr(self, '_vertices') and hasattr(self, 'uid') and hasattr(self, '_uid_lookup'): + if len(self._vertices) > 0 and self.uid in self._uid_lookup: + idx = self._uid_lookup[self.uid] + if idx < len(self._vertices): + root = self._vertices[idx] + if hasattr(root, 'clean_edges'): + root.clean_edges() + except (Exception,): + pass + finally: + # Always attempt to clear these collections, even if clean_edges() fails + try: + if hasattr(self, '_vertices'): + self._vertices.clear() + if hasattr(self, '_uid_lookup'): + self._uid_lookup.clear() + if hasattr(self, 'corrupt_uids'): + self.corrupt_uids.clear() + except (Exception,): + pass + + # Clear all collections to break circular references + self.read_struct_obj = None + del self.read_struct_obj + self.write_struct_obj = None + del self.write_struct_obj + self.conn = None + del self.conn + + def debug(self, msg: str, level: int = 0): + """ + Debug with granularity level. + + If the debug level is greater or equal to the level on the message, the message will be displayed. 
+ + :param msg: Text debug message + :param level: Debug level of message + :return: + """ + + if self.debug_level >= level: + + msg = f"{self.log_prefix}: {msg}" + + if self.logger is not None: + self.logger.debug(msg) + else: + logging.debug(msg) + + def debug_stacktrace(self): + exc = sys.exc_info()[0] + # the last one would be full_stack() + stack = traceback.extract_stack()[:-1] + if exc is not None: + del stack[-1] + trc = 'Traceback (most recent call last):\n' + msg = trc + ''.join(traceback.format_list(stack)) + if exc is not None: + msg += ' ' + traceback.format_exc().lstrip(trc) + self.debug(msg) + + def __str__(self): + ret = f"GraphSync {self.uid}\n" + ret += f" python instance id: {id(self)}\n" + ret += f" name: {self.name}\n" + ret += f" key: {self.key}\n" + ret += f" vertices:\n" + for v in self.all_vertices: + ret += f" * {v.uid}, Keys: {v.keychain}, Active: {v.active}\n" + for e in v.edges: + if e.edge_type == EdgeType.DATA: + ret += f" + has a DATA edge" + if e.content is not None: + ret += ", has content" + else: + ret += f" + belongs to {e.head_uid}, {DAG.EDGE_LABEL.get(e.edge_type)}, {e.content}" + ret += "\n" + + return ret + + @property + def is_corrupt(self): + return len(self.corrupt_uids) > 0 + + @property + def allow_auto_save(self) -> bool: + """ + Return the flag indicating if auto save is allowed. + :return: + """ + + return self._allow_auto_save + + @allow_auto_save.setter + def allow_auto_save(self, value: bool): + """ + Set the ability to auto save. + :param value: True enables, False disables. 
+ :return: + """ + + if value: + self.debug("ability to auto save has been ENABLED", level=2) + else: + self.debug("ability to auto save has been DISABLED", level=2) + + self._allow_auto_save = value + + @property + def origin_ref(self) -> Ref: + + """ + Return an instance of the origin reference for adding data + :return: + """ + + return Ref( + type=RefType.DEVICE, + value=self.origin_uid, + name=self.name if self.is_dev is True else None + ) + + @property + def has_graph(self) -> bool: + """ + Do we have any graph items? + + :return: True if there are vertices. False if no vertices. + """ + + return len(self._vertices) > 0 + + @property + def vertices(self) -> List[DAGVertex]: + """ + Get all active vertices + + :return: List of DAGVertex instance + """ + + return [ + vertex + for vertex in self._vertices + if vertex.active is True + ] + + @property + def all_vertices(self) -> List[DAGVertex]: + """ + Get all vertices + :return: List of DAGVertex instance + """ + + return self._vertices + + def get_vertex(self, key) -> Optional[DAGVertex]: + + """ + Get a single vertex. + + The key can be either a UID, path or name. + + The UID is most reliable since there can only be one per graph. + + The path is second reliable if it is set by the user. + It will find an edge with the path, the vertex that is the edge's tail is returned. + There is no unique constraint for the path. + You can have duplicates. + + The name is third, and not reliable. + The name only exists when the graph is created. + If loaded, the name will be None. + + :param key: A UID, path item, or name of a vertex. + :return: DAGVertex instance, if it exists. + """ + + if key is None: + return None + + # Is the key a UID? If so, return the vertex from the lookup. + if key in self._uid_lookup: + index = self._uid_lookup[key] + return self._vertices[index] + + # Is the key a path? + # We also want to include any deleted edges. 
+ vertices = self.get_vertices_by_path_value(key, inc_deleted=True) + if len(vertices) > 0: + if len(vertices) > 1: + raise DAGPathException("Cannot get vertex using the path. Found multiple vertex that use the path.") + return vertices[0] + + # Is the key a name? This is a last resort. + for vertex in vertices: + if vertex.name == key: + return vertex + + return None + + @property + def get_root(self) -> Optional[DAGVertex]: + """ + Get the root vertex + + If the root vertex does not exist, it will create the vertex with a ref type of PAM_NETWORK. + + :return: + """ + root = self.get_vertex(self.uid) + if root is None: + root = self.add_vertex(uid=self.uid, name=self.name, vertex_type=self.vertex_type) + return root + + def vertex_exists(self, key: str) -> bool: + """ + Check if a vertex identified by the key exists. + :param key: UID, path, or name + :return: + """ + + return self.get_vertex(key) is not None + + def get_vertices_by_path_value(self, path: str, inc_deleted: bool = False) -> List[DAGVertex]: + """ + Find all vertices that have an edge that match the path + :param path: A string path value. This is a path to walk, just the value. + :param inc_deleted: Include deleted edges. + :return: List of DAGVertex + """ + results = [] + if inc_deleted: + vertices = self.all_vertices + else: + vertices = self.vertices + + for vertex in vertices: + for edge in vertex.edges: + if edge.path == path: + results.append(vertex) + return results + + def _sync(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + + # The web service will send 500 items, if there is more the 'has_more' flag is set to True. 
+ has_more = True + + # Make the web service call to set all the data + all_data = [] + while has_more: + # Load a page worth of items + + sync_query = self.read_struct_obj.sync_query(stream_id=self.uid, + sync_point=sync_point, + graph_id=self.graph_id) + + results = self.read_struct_obj.get_sync_result( + self.conn.sync( + sync_query=sync_query, + graph_id=self.graph_id, + endpoint=self.read_endpoint, + agent=self.agent + )) + + if results.syncPoint == 0: + return all_data, 0 + + all_data += results.data + + # The server will tell us if there is more data to get. + has_more = results.hasMore + + # The sync_point will indicate where we need to start the sync from. Think syncPoint > value + sync_point = results.syncPoint + + return all_data, sync_point + + def _load(self, sync_point: int = 0): + + """ + Load the DAG + + This will clear the existing graph. + It will make web services calls to get the fresh graph, which will return a list of edges. + With the list of edges, it will create vertices and connect them with the edges. + The content of the edges will remain encrypted. The 'encrypted' flag is set to True. + We need the entire graph structure before decrypting. + + We don't have to worry about keys at this point. We are just trying to get structure + and content in the right place. Nothing is decrypted here. + + :param sync_point: Where to load + """ + + # Clear the existing vertices. + self._vertices = [] # type: List[DAGVertex] + self._uid_lookup = {} # type: dict[str, int] + + self.debug("# SYNC THE GRAPH ##################################################################", level=1) + + # Make the web service call to set all the data + all_data, sync_point = self._sync(sync_point=sync_point) + + self.debug(" PROCESS the non-DATA edges", level=2) + + # Process the non-DATA edges + for data in all_data: + + # Skip all the DATA edge + edge_type = EdgeType.find_enum(data.type) + if edge_type == EdgeType.DATA: + continue + + # The ref the tail. 
It connects to stored in the vertex. + tail_uid = data.ref.value + + # The parentRef is the head. It's the arrowhead on the edge. For DATA edges, it will be None. + head_uid = None + if data.parentRef is not None: + head_uid = data.parentRef.value + + self.debug(f" * edge {edge_type}, tail {tail_uid} to head {head_uid}", level=3) + + # We want to store this edge in the Vertex with the same value/UID as the ref. + if not self.vertex_exists(tail_uid): + self.debug(f" * tail vertex {tail_uid} does not exists. create.", level=3) + self.add_vertex( + uid=tail_uid, + name=data.ref.name, + + # This will be 0/GENERAL right now. We do the lookup just in case things will change in the + # future. + vertex_type=RefType.find_enum(data.ref.type) + ) + + # Get the tail vertex. + tail = self.get_vertex(tail_uid) + + # This most likely is a DELETION edge of a DATA edge. + # Set the head to be the same as the tail. + if head_uid is None: + head_uid = tail_uid + + # If the head vertex doesn't exist, we need to create. + if not self.vertex_exists(head_uid): + self.debug(f" * head vertex {head_uid} does not exists. create.", level=3) + self.add_vertex( + uid=head_uid, + name=data.parentRef.name, + vertex_type=RefType.GENERAL + ) + # Get the head vertex, which will exist now. + head = self.get_vertex(head_uid) + self.debug(f" * tail {tail_uid} belongs to {head_uid}, " + f"edge type {edge_type}", level=3) + + if edge_type == EdgeType.DELETION: + tail.disconnect_from(head) + else: + content = data.content + if content is not None: + if data.content_is_base64: + content = utils.base64_url_decode(content) + + # Connect this vertex to the head vertex. It belongs to that head vertex. + tail.belongs_to( + vertex=head, + edge_type=edge_type, + content=content, + # ACL and LINK edges are not encrypted. 
+ is_encrypted=False, + path=data.path, + modified=False, + from_load=True + ) + + self.debug("", level=2) + self.debug(" PROCESS the DATA edges", level=2) + + # Process the DATA edges + # We don't have to worry about vertex creation since they will all exist. + for data in all_data: + + # Only process the data edges. + edge_type = EdgeType.find_enum(data.type) + if edge_type != EdgeType.DATA: + continue + + # Get the tail vertex. + tail_uid = data.ref.value + # We want to store this edge in the Vertex with the same value/UID as the ref. + if not self.vertex_exists(tail_uid): + self.debug(f" * tail vertex {tail_uid} does not exists. create.", level=3) + self.add_vertex( + uid=tail_uid, + name=data.ref.name, + + # This will be 0/GENERAL right now. We do the lookup just in case things will change in the + # future. + vertex_type=RefType.find_enum(data.ref.type) + ) + tail = self.get_vertex(tail_uid) + + content = data.content + if content is not None: + if data.content_is_base64: + content = utils.base64_url_decode(content) + + self.debug(f" * DATA edge belongs to {tail.uid}", level=3) + tail.add_data( + content=content, + # Assume DATA is encrypted; it might not be but, we will handle that later. + is_encrypted=True, + path=data.path, + modified=False, + from_load=True, + ) + + self.debug("", level=1) + + return sync_point + + def _mark_deletion(self): + + """ + Mark vertices as deleted. + + Check each vertex to see if there is any non-DELETION edge connecting to another vertex. + If there are no edges, then the vertex is flagged at deleted. + + This is done to prevent the edges from being connected to a deleted vertex. + Also, to display deleted vertex in the DOT graph. + :return: + """ + + self.debug(" CHECK dag vertices to see if they are active", level=1) + for vertex in self.all_vertices: + + self.debug(f"check vertex {vertex.uid}", level=3) + found_edge_to_another_vertex = False + for edge in vertex.edges: + # Skip the DELETION and DATA edges. 
+ if edge.edge_type == EdgeType.DELETION or edge.edge_type == EdgeType.DATA: + continue + + # Check if this edge has a matching DELETION edge. + # If it does not, this vertex cannot be deleted. + if not edge.is_deleted: + found_edge_to_another_vertex = True + break + + # If the vertex belongs to no vertex, and it not the root, then flag it for deletion. + if found_edge_to_another_vertex is False and vertex.uid != self.uid: + self.debug(f" * vertex is deleted", level=3) + vertex.active = False + + self.debug("", level=1) + + def _decrypt_keychain(self): + + """ + Decrypt KEY/ACL edges + + Part one is to decrypt the KEY and ACL edges. + To decrypt the edge, we need to walk up the edges until we can no longer. + If we get the point where we can't walk up any farther, we need to use the record key bytes. + While walking up, if we get to a keychain that has been decrypted, we return that keychain. + As we walk back, we can decrypt any keychain that is still encrypted. + The decrypt keychain is set in the vertex. + """ + + self.debug(" DECRYPT the dag KEY edges", level=1) + + def _get_keychain(v): + self.debug(f" * looking at {v.uid}", level=3) + + # If the vertex has a decrypted key, then return it. + if v.has_decrypted_keys is True: + self.debug(" found a decrypted keychain on vertex", level=3) + return v.keychain + + # Else we need KEY/ACL edge and get the key from the vertex that this vertex belongs to + found_key_edge = False + for e in v.edges: + if e.edge_type == EdgeType.KEY: + + self.debug(f" has edge that is a key, check head vertex {e.head_uid}", level=3) + head = self.get_vertex(e.head_uid) + keychain = _get_keychain(head) + + # No need to check if keychain exists. + # At default, it should contain the record bytes if no KEY/ACL edges existed for a vertex. + + self.debug(f" * decrypt {v.uid} with keys {keychain}", level=3) + was_able_to_decrypt = False + + # Try the keys in the keychain. One should be able to decrypt the content. 
+ for key in keychain: + try: + # The edge will contain a single key. + # Adding a key to + self.debug(f" decrypt with key {key}", level=3) + content = dag_crypto.decrypt_aes(e.content, key) + self.debug(f" content {content}", level=3) + v.add_to_keychain(content) + self.debug(f" * vertex {v.uid} keychain is {v.keychain}", level=3) + was_able_to_decrypt = True + found_key_edge = True + break + except (Exception,): + self.debug(f" !! this is not the key", level=3) + + if not was_able_to_decrypt: + + # Flag that the edge is corrupt, flag that the vertex keychain is corrupt, + # and store vertex UID/tail UID. + # If we fail on corrupt keys, then raise exceptions. + e.corrupt = True + v.corrupt = True + self.corrupt_uids.append(v.uid) + if self.fail_on_corrupt: + raise DAGKeyException(f"Could not decrypt vertex {v.uid} keychain for edge path {e.path}") + return [] + + if found_key_edge: + return v.keychain + else: + self.debug(" * using record bytes", level=3) + return [self.key] + + for vertex in self.all_vertices: + if not vertex.has_key: + continue + self.debug(f"vertex {vertex.uid}, {vertex.has_key}, {vertex.has_decrypted_keys}", level=3) + vertex.keychain = _get_keychain(vertex) + self.debug(f" setting keychain to {vertex.keychain}", level=3) + + self.debug("", level=1) + + def _decrypt_data(self): + + """ + Decrypt DATA edges + + At this point, all the vertex should have an encrypted key. + This key is used to decrypt the DATA edge's content. + Walk each vertex and decrypt the DATA edge if there is a DATA edge. + """ + + self.debug(" DECRYPT the dag data", level=1) + for vertex in self.all_vertices: + if not vertex.has_data: + continue + self.debug(f"vertex {vertex.uid}, {vertex.keychain}", level=3) + + for edge in vertex.edges: + if edge.edge_type != EdgeType.DATA: + continue + + # If the vertex/KEY edge that tail is this vertex is corrupt, we cannot decrypt data. 
+ if vertex.corrupt: + self.debug(f"the key for the DATA edge is corrupt for vertex {vertex.uid}; " + "cannot decrypt data.", level=3) + continue + + if not edge.is_encrypted: + raise ValueError("The content has already been decrypted.") + + content = edge.content + + self.debug(f" * enc safe content {content}", level=3) + self.debug(f" * enc {content}, enc key {vertex.keychain}", level=3) + able_to_decrypt = False + + keychain = vertex.keychain + + # Try the keys in the keychain. One should be able to decrypt the content. + for key in keychain: + try: + edge.content = dag_crypto.decrypt_aes(content, key) + able_to_decrypt = True + self.debug(f" * content {edge.content}", level=3) + break + except (Exception,): + self.debug(f" !! this is not the key", level=3) + + if not able_to_decrypt: + + # If the DATA edge requires encryption, throw error if we cannot decrypt. + if self.data_requires_encryption: + self.corrupt_uids.append(vertex.uid) + raise DAGDataException(f"The data edge {vertex.uid} could not be decrypted.") + + edge.content = content + edge.needs_encryption = False + self.debug(f" * edge is not encrypted or key is incorrect.") + + # Change the flag indicating that the content is in decrypted state. + edge.is_encrypted = False + + self.debug("", level=1) + + def _flag_as_not_modified(self): + + """ + Flag all edges a not modified. + + :return: + """ + + for vertex in self.all_vertices: + for edge in vertex.edges: + edge.modified = False + + def load(self, sync_point: int = 0) -> int: + + """ + Load data from the graph. + + The first step is to recreate the structure of the graph. + The second step is mark vertex as deleted. + The third step is to decrypt the KEY/ACL/DATA edges. + Forth is to flag all edges as not modified. 
+ + :return: The sync point of the graph stream + """ + + # During the load, turn off auto save + self.allow_auto_save = False + + self.debug("== LOAD DAG ========================================================================", level=2) + sync_point = self._load(sync_point) + self.debug(f"sync point is {sync_point}") + self._mark_deletion() + if self.decrypt: + self._decrypt_keychain() + self._decrypt_data() + else: + self.logger.info("the DAG has not been decrypted, the decrypt flag was get to False") + self._flag_as_not_modified() + self.debug("====================================================================================", level=2) + + # We have loaded the graph, enable the ability to use auto save. + self.allow_auto_save = True + + self.last_sync_point = sync_point + + return sync_point + + def _make_delta_graph(self, duplicate_data: bool = True): + + self.debug("DELTA GRAPH", level=3) + modified_vertices = [] + for vertex in self.all_vertices: + found_modification = False + for edge in vertex.edges: + if edge.skip_on_save: + continue + if edge.modified: + found_modification = True + break + if found_modification: + modified_vertices.append(vertex) + if len(modified_vertices) == 0: + self.debug("nothing has been modified") + return + + self.debug(f"has {len(modified_vertices)} vertices", level=3) + + def _flag(v: DAGVertex): + + self.debug(f"check vertex {v.uid}", level=3) + if v.uid == self.uid: + self.debug(f" FOUND ROOT", level=3) + return True + + # Check if we have any of these edges in this order. + found_path = False + for edge_type in [EdgeType.KEY, EdgeType.ACL, EdgeType.LINK]: + seen = {} + for e in v.edges: + self.debug(f" checking {e.edge_type}, {v.uid} to {e.head_uid}", level=3) + is_deletion = None + if e.edge_type == edge_type: + self.debug(f" found {edge_type}", level=3) + next_vertex = self.get_vertex(e.head_uid) + + if is_deletion is None: + + # If the most recent edge a DELETION edge? 
+ version, highest_edge = v.get_highest_edge_version(next_vertex.uid) + is_deletion = highest_edge.edge_type == EdgeType.DELETION + if is_deletion: + self.debug(f" highest deletion edge. will not mark any edges as modified", + level=3) + + found_path = _flag(next_vertex) + if found_path is True and seen.get(e.head_uid) is None: + self.debug(f" setting {v.uid}, {edge_type} active", level=3) + if not is_deletion: + e.modified = True + seen[e.head_uid] = True + else: + self.debug(f" edge is not {edge_type}", level=3) + + if found_path is True: + break + + # If we found a path, we may need to duplicate the DATA edge. + if found_path is True and duplicate_data is True: + for e in v.edges: + if e.edge_type == EdgeType.DATA: + e.modified = True + break + + return found_path + + self.logger.debug("BEGIN delta graph edge detection") + for modified_vertex in modified_vertices: + _flag(modified_vertex) + self.logger.debug("END delta graph edge detection") + + def save(self, confirm: bool = False, delta_graph: bool = False): + + """ + Save the graph + + We will not save if using the default graph. + + The save process will only save edges that have been flagged as modified, or are newly added. + The process will get the edges from all vertices. + The UID of the vertex is the tail UID of the edge. + For DATA edges, the key (first key in the keychain) will be used for encryption. + + If the web service takes too long or hangs, the batch_count can be used to reduce the amount the web service + needs to handle per request. If set to None or non-postivie value, it will not send in batches. + + :param confirm: Confirm save. + Only need this when deleting all vertices. + :param delta_graph: Make a standalone graph from the modifications. + Use sync points to load this graph. 
+ + :return: + """ + + self.debug("== SAVE GRAPH ========================================================================", level=2) + + if self.is_corrupt: + self.logger.error(f"the graph is corrupt, there are problem UIDs: {','.join(self.corrupt_uids)}") + raise DAGCorruptException(f"Cannot save. Graph steam uid {self.uid}, " + f"graph {self.write_endpoint}:{self.graph_id} " + f"has corrupt vertices: {','.join(self.corrupt_uids)}") + + root_vertex = self.get_vertex(self.uid) + if root_vertex is None: + raise DAGVertexException("Cannot save. Could not find the root vertex.") + + if root_vertex.vertex_type != RefType.PAM_NETWORK and root_vertex.vertex_type != RefType.PAM_USER: + raise DAGVertexException("Cannot save. Root vertex type needs to be PAM_NETWORK or PAM_USER.") + + # Do we need to the 'confirm' parameter set to True? + # This is needed if the entire graph is being deleted. + if self.need_save_confirm is True and confirm is False: + raise DAGConfirmException("Cannot save. Confirmation is required.") + self.need_save_confirm = False + + if delta_graph: + self._make_delta_graph() + + data_list = [] + + def _add_data(vertex): + self.debug(f"processing vertex {vertex.uid}, key {vertex.key}, type {vertex.vertex_type}", level=3) + # The vertex UID and edge tail UID + uid = vertex.uid + for edge in vertex.edges: + + if edge.skip_on_save: + continue + + self.debug(f" * edge {edge.edge_type.value}, head {edge.head_uid}, tail {vertex.uid}", level=3) + + # If this edge is not modified, don't add to the data list to save. + if not edge.modified: + self.debug(f" not modified, not saving.", level=3) + continue + + content = edge.content + + # If we are decrypting the edge data, then we want to encrypt it when we save. + # Else, save the content as it is. 
+ if self.decrypt: + if edge.edge_type == EdgeType.DATA: + self.debug(f" edge is data, encrypt data: {edge.needs_encryption}", level=3) + if isinstance(content, dict): + content = json.dumps(content).encode() + if isinstance(content, str): + content = content.encode() + + # If individual edges require encryption or all DATA edge require encryption, then encrypt + if edge.needs_encryption is True or self.data_requires_encryption is True: + self.debug(f" content {edge.content}, enc key {vertex.key}", level=3) + content = dag_crypto.encrypt_aes(content, vertex.key) + self.debug(f" enc content {content}", level=3) + + self.debug(f" enc safe content {content}", level=3) + elif edge.edge_type == EdgeType.KEY: + self.debug(f" edge is key or acl, encrypt key", level=3) + head_vertex = self.get_vertex(edge.head_uid) + key = head_vertex.key + if key is None: + self.debug(f" the edges head vertex {edge.head_uid} did not have a key. " + "using root dag key.", level=3) + key = self.key + self.debug(f" key {vertex.key}, enc key {key}", level=3) + content = dag_crypto.encrypt_aes(vertex.key, key) + elif edge.edge_type == EdgeType.ACL: + content = edge.content + else: + self.debug(f" edge is {edge.edge_type}", level=3) + + parent_vertex = self.get_vertex(edge.head_uid) + + if content is not None and len(content) > 65_535: + self.debug(f" !! vertex {vertex.uid} data edge is {len(content)}, over 64K.") + raise DAGDataException(f"vertex {vertex.uid} DATA edge is {len(content)} bytes. " + "This is too large for the MySQL BLOB (64K).") + + dag_data = self.write_struct_obj.data( + data_type=edge.edge_type, + content=content, + tail_uid=uid, + tail_ref_type=vertex.vertex_type, + tail_name=vertex.name if self.is_dev is True else None, + head_uid=edge.head_uid, + head_ref_type=parent_vertex.vertex_type, + head_name=parent_vertex.name if self.is_dev is True else None, + path=edge.path) + + data_list.append(dag_data) + + # Flag that this edge is no longer modified. 
+ edge.modified = False + + # Add the root vertex first + _add_data(self.get_root) + + # Add the rest. + # Only add is the skip_save is False. + for v in self.all_vertices: + if v.skip_save is False: + if v.uid != self.uid: + _add_data(v) + + # Save the keys before the data. + # This is done to make sure the web service can figure out the stream id. + # By saving the keys before data, the structure of the graph is formed. + if len(data_list) > 0: + + if self.debug_level >= 4: + + self.debug("EDGE LIST") + self.debug("##############################################") + for data in data_list: + self.debug(f"{data.ref.value} -> {data.parentRef.value} ({data.type})") + self.debug("##############################################") + + self.debug(f"total list has {len(data_list)} items", level=0) + self.debug(f"batch {self.save_batch_count} edges", level=0) + + batch_num = 0 + while len(data_list) > 0: + + # If using batch add, then take the first batch_count items. + # Remove them from the data list + if self.save_batch_count > 0: + batch_list = data_list[:self.save_batch_count] + data_list = data_list[self.save_batch_count:] + + # Else take everything and clear the data list (else infinite loop) + else: + batch_list = data_list + data_list = [] + + # Little sanity check + if len(batch_list) == 0: + break + + self.debug(f"adding {len(batch_list)} edges, batch {batch_num}", level=0) + + payload = self.write_struct_obj.payload( + origin_ref=self.write_struct_obj.origin_ref( + origin_ref_value=self.origin_ref_value, + name=self.name if self.is_dev is True else None + ), + data_list=batch_list, + graph_id=self.graph_id + ) + + try: + self.conn.add_data(payload, + graph_id=self.graph_id, + endpoint=self.write_endpoint, + agent=self.agent) + except Exception as err: + self.logger.error(f"could not add data to graph for batch {batch_num}: {err}") + raise err + + batch_num += 1 + + # It's a POST that returns no data + else: + self.debug("data list was empty, not saving.", 
level=2) + + self.debug("====================================================================================", level=2) + + def do_auto_save(self): + # If allow_auto_save is False, we will not allow auto saving. + # On newly created graph, this will happen if the root vertex has not been connected. + # The root vertex/disconnect edge head is needed to get a proper stream ID. + if not self.allow_auto_save: + self.debug("cannot auto_save, allow_auto_save is False.", level=3) + return + if self.auto_save: + self.debug("... dag auto saving", level=1) + self.save() + + def add_vertex(self, name: Optional[str] = None, uid: Optional[str] = None, keychain: Optional[List[bytes]] = None, + vertex_type: RefType = RefType.GENERAL) -> DAGVertex: + + """ + Add a vertex to the graph. + + :param name: Name for the vertex. + :param uid: String unique identifier. + It's a 16bit hex value that is base64 encoded. + :param keychain: List if key bytes to use for encryption/description. This is set by the load/save method. + :param vertex_type: A RefType enumeration type. If blank, it will default to GENERAL. + :return: + """ + + if name is None: + name = uid + + vertex = DAGVertex( + name=name, + dag=self, + uid=uid, + keychain=keychain, + vertex_type=vertex_type + ) + if self.vertex_exists(vertex.uid): + raise DAGVertexAlreadyExistsException(f"Vertex {vertex.uid} already exists.") + + # Set the UID to array index lookup. + # This is where the vertex will be in the vertices list. + # Then append the vertex to the vertices list. 
+ self._uid_lookup[vertex.uid] = len(self._vertices) + self._vertices.append(vertex) + + return vertex + + @property + def is_modified(self) -> bool: + for vertex in self.all_vertices: + for edge in vertex.edges: + if edge.modified is True: + return True + return False + + @property + def modified_edges(self): + edges = [] + for vertex in self.all_vertices: + for edge in vertex.edges: + if edge.modified is True: + edges.append(edge) + return edges + + def delete(self): + """ + Delete the entire graph. + + This will delete all the vertex, which will delete all the edges. + This will not automatically save. + The save method will need to be called. + The save will require the 'confirm' parameter to be set to True. + :return: + """ + for vertex in self.vertices: + vertex.delete() + self.need_save_confirm = True + + def _search(self, content: Any, value: QueryValue, ignore_case: bool = False): + + if isinstance(value, dict): + # If the object is not a dictionary, then it's not match + if not isinstance(content, dict): + return False + for next_key, next_value in value.items(): + if next_key not in content: + return False + if not self._search(content=content[next_key], + value=next_value, + ignore_case=ignore_case): + return False + return True + elif isinstance(value, list): + # If the object is not a dictionary, then it's not match + for next_value in value: + if self._search(content=content, + value=next_value, + ignore_case=ignore_case): + return True + return False + else: + content = str(content) + value = str(value) + if ignore_case: + content = content.lower() + value = value.lower() + + return value in content + + def search_content(self, query, ignore_case: bool = False): + results = [] + for vertex in self.vertices: + if vertex.has_data is False or vertex.active is False: + continue + content = vertex.content + if isinstance(query, bytes): + if query == content: + results.append(vertex) + elif isinstance(query, str): + try: + content = content.decode() + if 
query in content:
+                        results.append(vertex)
+                    continue
+
+                except (Exception,):
+                    pass
+            elif isinstance(query, dict):
+                try:
+                    content = content.decode()
+                    content = json.loads(content)
+                    search_result = self._search(content, value=query, ignore_case=ignore_case)
+                    if search_result:
+                        results.append(vertex)
+                except (Exception,):
+                    pass
+            else:
+                raise ValueError("Query is not an accepted type.")
+        return results
+
+    def walk_down_path(self, path: Union[str, List[str]]) -> Optional[DAGVertex]:
+
+        """
+        Walk the vertices using the path and return the vertex starting at root vertex.
+
+        :param path: An array of path strings, or a string where the path is joined with a "/" (i.e., think URL)
+        :return: DAGVertex if the path completes, None on failure.
+        """
+
+        self.debug("walking path starting at the root vertex", level=2)
+        vertex = self.get_vertex(self.uid)
+        return vertex.walk_down_path(path)
+
+    def edge_count(self, only_active: bool = True, only_modified: bool = False) -> int:
+        """
+        Return number of edges in graph.
+
+        Edges that have been flagged as dups will not be included.
+
+        :param only_active: Default True. If True, only edges that are active (not DELETION) are counted.
+        :param only_modified: Default False. If True, only edges that are new or have been modified are counted.
+        :return:
+        """
+
+        count = 0
+        for v in self.all_vertices:
+            for e in v.edges:
+                if e.skip_on_save:
+                    continue
+                if (only_active and not e.active) or (only_modified and not e.modified):
+                    continue
+                count += 1
+
+        return count
+
+    def to_dot(self, graph_format: str = "svg", show_hex_uid: bool = False,
+               show_version: bool = True, show_only_active: bool = False):
+
+        """
+        Generate a graphviz Digraph in DOT format that is marked up.
+ + :param graph_format: + :param show_hex_uid: + :param show_version: + :param show_only_active: + :return: + """ + + try: + mod = importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"GraphSync for {self.name}", format=graph_format) + dot.attr(rankdir='BT') + + for v in self._vertices: + if show_only_active is True and v.active is False: + continue + if not v.corrupt: + fillcolor = "white" + if not v.active: + fillcolor = "grey" + label = f"uid={v.uid}" + if v.name is not None and v.name != v.uid: + label += f"\\nname={v.name}" + if show_hex_uid: + label += f"\\nhex={dag_crypto.urlsafe_str_to_bytes(v.uid).hex()}" + else: + fillcolor = "red" + label = f"{v.uid} (CORRUPT)" + + dot.node(v.uid, label, fillcolor=fillcolor, style="filled") + for edge in v.edges: + if edge.skip_on_save: + continue + + if not edge.corrupt: + color = "grey" + style = "solid" + + # To reduce the number of edges, only show the active edges + if edge.active: + color = "black" + style = "bold" + elif show_only_active: + continue + + # If the vertex is not active, gray out the DATA edge + if not edge.active: + color = "grey" + + if edge.edge_type == EdgeType.DELETION: + style = "dotted" + + label = DAG.EDGE_LABEL.get(edge.edge_type) + if label is None: + label = "UNK" + if edge.path is not None and edge.path != "": + label += f"\\npath={edge.path}" + if show_version: + label += f"\\ne{edge.version}" + # tail, head (arrow side), label + else: + color = "red" + style = "solid" + label = f"{DAG.EDGE_LABEL.get(edge.edge_type)} (CORRUPT)" + + dot.edge(v.uid, edge.head_uid, label, style=style, fontcolor=color, color=color) + + return dot + + def to_dot_raw(self, graph_format: str = "svg", sync_point: int = 0, rank_dir="BT"): + + """ + Generate a graphviz Gigraph in DOT format that is not (heavily) marked up. 
+ + :param graph_format: + :param sync_point: + :param rank_dir: + :return: + """ + + try: + mod = importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"GraphSync for {self.name}", format=graph_format) + dot.attr(rankdir=rank_dir) + + all_data, sync_point = self._sync(sync_point=sync_point) + + for edge in all_data: + edge_type = edge.type + tail_uid = edge.ref.value + dot.node(tail_uid, tail_uid) + if edge.parentRef is not None: + head_uid = edge.parentRef.value + dot.edge(tail_uid, head_uid, edge_type) + else: + dot.edge(tail_uid, tail_uid, edge_type) + return dot diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_crypto.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_crypto.py new file mode 100644 index 00000000..890e23ca --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_crypto.py @@ -0,0 +1,52 @@ + +import base64 +from typing import Optional, Union +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +import os + + +def encrypt_aes(data: bytes, key: bytes, iv: bytes = None) -> bytes: + aesgcm = AESGCM(key) + iv = iv or os.urandom(12) + enc = aesgcm.encrypt(iv, data, None) + return iv + enc + + +def decrypt_aes(data: bytes, key: bytes) -> bytes: + aesgcm = AESGCM(key) + return aesgcm.decrypt(data[:12], data[12:], None) + + +def bytes_to_urlsafe_str(b: Union[str, bytes]) -> str: + """ + Convert bytes to a URL-safe base64 encoded string. + + Args: + b (bytes): The bytes to be encoded. + + Returns: + str: The URL-safe base64 encoded representation of the input bytes. 
+ """ + if isinstance(b, str): + b = b.encode() + + return base64.urlsafe_b64encode(b).decode().rstrip('=') + + +def generate_random_bytes(length: int) -> bytes: + return os.urandom(length) + + +def generate_uid_bytes(length: int = 16) -> bytes: + return generate_random_bytes(length) + + +def generate_uid_str(uid_bytes: Optional[bytes] = None) -> str: + if uid_bytes is None: + uid_bytes = generate_uid_bytes() + return bytes_to_urlsafe_str(uid_bytes) + + +def urlsafe_str_to_bytes(s: str) -> bytes: + b = base64.urlsafe_b64decode(s + '==') + return b \ No newline at end of file diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_edge.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_edge.py new file mode 100644 index 00000000..408c67ea --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_edge.py @@ -0,0 +1,285 @@ +import json +import logging +from typing import Optional, Union, Any, TYPE_CHECKING + +import pydantic +from ..keeper_dag.dag_types import EdgeType +from ..keeper_dag.exceptions import DAGContentException + +if TYPE_CHECKING: + from ..keeper_dag.dag_vertex import DAGVertex + +class DAGEdge: + def __init__(self, + vertex: "DAGVertex", + edge_type: EdgeType, + head_uid: str, + version: int = 0, + content: Optional[Any] = None, + path: Optional[str] = None, + modified: bool = True, + block_content_auto_save: bool = False, + is_serialized: bool = False, + is_encrypted: bool = False, + needs_encryption: bool = False): + """ + Create an instance of DAGEdge. + + A primary key of the edge the vertex UID, the head UID, and edge_type. + + :param vertex: The DAGVertex instance that owns these edges. + :param edge_type: The enumeration EdgeType. Indicate the type of the edge. + :param head_uid: The vertex uid that has this edge's vertex. The vertex uid that the edge arrow points at. + :param version: Version of this edge. + :param content: The content of this edge. + :param path: Short tag about this edge. 
Do + :param modified: + :param block_content_auto_save: + :param from_load: Is this being called from the load() method? + :param is_serialized: From the load, is the content serialized from to a base64 string? + :param needs_encryption: Flag to indicate if the content needs to be encrypted. + :return: An instance of DAGEdge + """ + + # This is the vertex that owns this edge. + self.vertex = vertex + self.edge_type = edge_type + self.head_uid = head_uid + + # Flag to indicate if the edge has been modified. Used to determine if the edge should be part of saved data. + # Set this before setting the content, else setting the content will cause an auto save. + self._modified = None + self.modified = modified + + # Should this edge be skipped when saving. + # This could happen if we create a duplicate new or modified edge. + # We want to only save the newest duplicated edge, so skip prior ones. + self.skip_on_save: bool = False + + # Block auto save in the content setter. + # When creating an edge, don't save until the edge is added to the edge list. + self.block_content_auto_save = block_content_auto_save + + # Does this edge's content need encryption? + self.needs_encryption = needs_encryption + + # If the content is being populated from a the load() method, and the edge type is a KEY or DATA, then the + # content will be encrypted (str). + # We want to keep a str, unless KEYs are decrypted. + + # If the edge data need encryption, is _content, currently encrypted. + self.is_encrypted = is_encrypted + + # If the content could not be decrypted, set + self.corrupt = False + + # Is the content base64 encoded? + # For JSON non-DATA edges, it will be deserialized. + # For JSON DATA edges, it will be serialized and the decryption will deserialize it. + # For Protobuf all edges are deserialized; Protobuf does not serialize bytes. 
+ self.is_serialized = is_serialized + + self._content = None # type: Optional[Any] + self.content = content + self.path = path + + self.version = version + + # If a higher version edge exists, this will be False. + # If True, this is the highest edge. + self.active = True + + def __str__(self) -> str: + return f"" + + def debug(self, msg, level=0): + self.vertex.dag.debug(msg, level=level) + + @property + def modified(self): + return self._modified + + @modified.setter + def modified(self, value): + if value is True: + self.debug(f"vertex {self.vertex.uid}, type {self.vertex.dag.__class__.EDGE_LABEL.get(self.edge_type)}, " + f"head {self.head_uid} has been modified", level=5) + else: + self.debug(f"vertex {self.vertex.uid}, type {self.vertex.dag.__class__.EDGE_LABEL.get(self.edge_type)}, " + f"head {self.head_uid} had modified RESET", level=5) + self._modified = value + + @property + def content(self) -> Optional[Union[str, bytes]]: + """ + Get the content of the edge. + + If the content is a str, then the content is encrypted. + """ + + return self._content + + @property + def content_as_dict(self) -> Optional[dict]: + """ + Get the content from the DATA edge as a dictionary. + :return: Content as a dictionary. + """ + + if self.is_encrypted: + raise DAGContentException("The content is still encrypted.") + + content = self.content + if content is not None: + try: + content = json.loads(content) + except Exception as err: + raise DAGContentException(f"Cannot decode JSON. Is the content a dictionary? 
: {err}") + return content + + @property + def content_as_str(self) -> Optional[str]: + """ + Get the content from the DATA edge as string + :return: + """ + + if self.is_encrypted: + raise DAGContentException("The content is still encrypted.") + + content = self.content + try: + content = content.decode() + except (Exception,): + pass + return content + + def content_as_object(self, + meta_class: pydantic._internal._model_construction.ModelMetaclass) -> Optional[pydantic.BaseModel]: + """ + Get the content as a pydantic based object. + + :param meta_class: The class to return + :return: + """ + + if self.is_encrypted: + raise DAGContentException("The content is still encrypted.") + + content = self.content_as_str + if content is not None: + content = meta_class.model_validate_json(self.content_as_str) + return content + + @content.setter + def content(self, value: Any): + + """ + Set the content in the edge. + + The content should be stored as bytes. + If the encrypted flag is set, the content will be stored as is. + Content that is a str type is encrypted data (A Base64, AES encrypted bytes, str) + """ + + self.debug(f"vertex {self.vertex.uid}, type {self.vertex.dag.__class__.EDGE_LABEL.get(self.edge_type)}, " + f"head {self.head_uid} setting content", level=2) + + # If the data is encrypted, set it. + # Don't try to make it bytes. + # Also don't set the modified flag to True. + if self.is_encrypted: + self.debug(" content is encrypted.", level=3) + self._content = value + return + + if self._content is not None: + raise DAGContentException("Cannot update existing content. Use add_data() to change the content.") + + if isinstance(value, dict): + value = json.dumps(value) + + # Is this a Pydantic based class? + if hasattr(value, "model_dump_json"): + value = value.model_dump_json() + + if isinstance(value, str): + value = value.encode() + + self._content = value + + def delete(self): + """ + Delete the edge. + + Deleting an edge does not remove the existing edge. 
+ It will create another edge with the same tail and head, but will be type DELETION. + """ + + # If already inactive, return + if not self.active: + return + + version, _ = self.vertex.get_highest_edge_version(head_uid=self.head_uid) + + # Flag all other edges as inactive. + for edge in self.vertex.edges: + edge.active = False + if self.vertex.dag.dedup_edge and edge.modified: + edge.skip_on_save = True + + # Check if the all the edges for this vertex are all newly added/modified for the session. + # If they are all new, then we don't need to delete anything; we just don't add the edges. + all_modified_edges = True + for edge in self.vertex.edges: + if not edge.modified: + all_modified_edges = False + break + + # If not all the edges are modified, then we actually want to add a deletion edge. + # There is actually something to delete. + if not all_modified_edges: + self.vertex.edges.append( + DAGEdge( + vertex=self.vertex, + edge_type=EdgeType.DELETION, + head_uid=self.head_uid, + version=version + 1 + ) + ) + + # Perform the DELETION edges save in one batch. + # Get the current allowed auto save state, and disable auto save. + current_allow_auto_save = self.vertex.dag.allow_auto_save + self.vertex.dag.allow_auto_save = False + + if not self.vertex.belongs_to_a_vertex: + self.vertex.delete(ignore_vertex=self.vertex) + + self.vertex.dag.allow_auto_save = current_allow_auto_save + self.vertex.dag.do_auto_save() + + @property + def is_deleted(self) -> bool: + """ + Does this edge have a DELETION edge that has the same head? + + This should be used to check in a non-DELETION edge type has a matching DELETION edge. + :return: + """ + + # We shouldn't be checking the DELETION edge if it deleted. + # Throw some info message to make sure the coder knows their code might be something foolish. 
+ if self.edge_type == EdgeType.DELETION: + logging.info(f"The edge is_deleted() just check if the DELETION edge is DELETION " + f"for vertex {self.vertex.uid}, head UID {self.head_uid}. Returned True, but code should " + "not be checking this edge.") + return True + + # Check the other edges for this vertex for an active DELETION-edge type. + for edge in self.vertex.edges: + if edge.edge_type == EdgeType.DELETION and edge.head_uid == self.head_uid and edge.active is True: + return True + + return False diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_sort.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_sort.py new file mode 100644 index 00000000..6e6cab73 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_sort.py @@ -0,0 +1,122 @@ + + +import functools +import logging +import re +from typing import TYPE_CHECKING, List, Optional, Union +from .constants import VERTICES_SORT_MAP +from .dag_vertex import DAGVertex +from .dag_types import DiscoveryObject +Logger = Union[logging.RootLogger, logging.Logger] +if TYPE_CHECKING: + from .dag_vertex import DAGVertex + + +def sort_infra_name(vertices: List[DAGVertex]) -> List[DAGVertex]: + """ + Sort the vertices by name in ascending order. + """ + + def _sort(t1: DAGVertex, t2: DAGVertex): + t1_name = t1.content_as_dict.get("name") + t2_name = t2.content_as_dict.get("name") + if t1_name < t2_name: + return -1 + elif t1_name > t2_name: + return 1 + else: + return 0 + + return sorted(vertices, key=functools.cmp_to_key(_sort)) + + +def sort_infra_host(vertices: List[DAGVertex]) -> List[DAGVertex]: + """ + Sort the vertices by host name. + + Host name should appear first in ascending order. + IP should appear second in ascending order. 
+ + """ + + def _is_ip(host: str) -> bool: + if re.match(r'^\d+\.\d+\.\d+\.\d+', host) is not None: + return True + return False + + def _make_ip_number(ip: str) -> int: + ip_port = ip.split(":") + parts = ip_port[0].split(".") + value = "" + for part in parts: + value += part.zfill(3) + return int(value) + + def _sort(t1: DAGVertex, t2: DAGVertex): + t1_name = t1.content_as_dict.get("name") + t2_name = t2.content_as_dict.get("name") + + # Both names are ip addresses + if _is_ip(t1_name) and _is_ip(t2_name): + t1_num = _make_ip_number(t1_name) + t2_num = _make_ip_number(t2_name) + + if t1_num < t2_num: + return -1 + elif t1_num > t2_num: + return 1 + else: + return 0 + + # T1 is an IP, T2 is a host name + elif _is_ip(t1_name) and not _is_ip(t2_name): + return 1 + # T2 is not an IP and T2 is an IP + elif not _is_ip(t1_name) and _is_ip(t2_name): + return -1 + # T1 and T2 are host name + else: + if t1_name < t2_name: + return -1 + elif t1_name > t2_name: + return 1 + else: + return 0 + + return sorted(vertices, key=functools.cmp_to_key(_sort)) + + +def sort_infra_vertices(current_vertex: DAGVertex, logger: Optional[logging.Logger] = None) -> dict: + + if logger is None: + logger = logging.getLogger() + + # Make a map, record type to list of vertices (of that record type) + record_type_to_vertices_map = {k: [] for k, v in VERTICES_SORT_MAP.items()} + + # Collate the vertices into a record type lookup. + vertices = current_vertex.has_vertices() + logger.debug(f" found {len(vertices)} vertices") + for vertex in vertices: + if vertex.active is True: + content = DiscoveryObject.get_discovery_object(vertex) + logger.debug(f" * {content.description}") + for vertex in vertices: + if vertex.active is False: + logger.debug(" vertex is not active") + continue + # We can't load into a pydantic object since Pydantic has a problem with Union type. + # We only want the record type, so it is too much work to try to get into an object. 
+ content_dict = vertex.content_as_dict + record_type = content_dict.get("record_type") + if record_type in record_type_to_vertices_map: + record_type_to_vertices_map[record_type].append(vertex) + + # Sort the vertices for each record type. + for k, v in VERTICES_SORT_MAP.items(): + if v["sort"] == "sort_infra_name": + record_type_to_vertices_map[k] = sort_infra_name(record_type_to_vertices_map[k]) + elif v["sort"] == "sort_infra_host": + record_type_to_vertices_map[k] = sort_infra_host(record_type_to_vertices_map[k]) + + return record_type_to_vertices_map diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py new file mode 100644 index 00000000..512736bf --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py @@ -0,0 +1,825 @@ +import base64 +import datetime +from enum import Enum +import json +import time +from pydantic import BaseModel +from typing import Any, List, Optional, Union, TYPE_CHECKING + +from . 
import dag_crypto + +if TYPE_CHECKING: + from .dag_vertex import DAGVertex + +class BaseEnum(Enum): + + @classmethod + def find_enum(cls, value: Union[Enum, str, int], default: Optional[Enum] = None): + if value is not None: + for e in cls: + if e == value or e.value == value: + return e + if hasattr(cls, str(value).upper()): + return getattr(cls, value.upper()) + return default + + +class PamGraphId(BaseEnum): + PAM = 0 + DISCOVERY_RULES = 10 + DISCOVERY_JOBS = 11 + INFRASTRUCTURE = 12 + SERVICE_LINKS = 13 + + +class PamEndpoints(BaseEnum): + PAM = "/graph-sync/pam" + DISCOVERY_RULES = "/graph-sync/discovery_rules" + DISCOVERY_JOBS = "/graph-sync/discovery_jobs" + INFRASTRUCTURE = "/graph-sync/infrastructure" + SERVICE_LINKS = "/graph-sync/service_links" + + +ENDPOINT_TO_GRAPH_ID_MAP = { + PamEndpoints.PAM.value: PamGraphId.PAM.value, + PamEndpoints.DISCOVERY_RULES.value: PamGraphId.DISCOVERY_RULES.value, + PamEndpoints.DISCOVERY_JOBS.value: PamGraphId.DISCOVERY_JOBS.value, + PamEndpoints.INFRASTRUCTURE.value: PamGraphId.INFRASTRUCTURE.value, + PamEndpoints.SERVICE_LINKS.value: PamGraphId.SERVICE_LINKS.value, +} + + +class SyncQuery(BaseModel): + streamId: Optional[str] = None # base64 of a user's ID who is syncing. 
    # Device identifier of the syncing client.
    deviceId: Optional[str] = None
    # Last sync point seen by the client; server returns changes after this.
    syncPoint: Optional[int] = None
    # Which graph to sync; 0 is the default (PAM) graph.
    graphId: Optional[int] = 0


class RefType(BaseEnum):
    """Vertex reference types.

    The numeric comments give the server-side ordinal for each value.
    """
    # 0
    GENERAL = "general"
    # 1
    USER = "user"
    # 2
    DEVICE = "device"
    # 3
    REC = "rec"
    # 4
    FOLDER = "folder"
    # 5
    TEAM = "team"
    # 6
    ENTERPRISE = "enterprise"
    # 7
    PAM_DIRECTORY = "pam_directory"
    # 8
    PAM_MACHINE = "pam_machine"
    # 9
    PAM_DATABASE = "pam_database"
    # 10
    PAM_USER = "pam_user"
    # 11
    PAM_NETWORK = "pam_network"
    # 12
    PAM_BROWSER = "pam_browser"
    # 13
    # NOTE(review): "connetion" is a misspelling of "connection", but this string is a
    # persisted/wire value — do not correct it without confirming the server and any
    # stored graphs accept the corrected spelling.
    CONNECTION = "connetion"
    # 14
    WORKFLOW = "workflow"
    # 15
    NOTIFICATION = "notification"
    # 16
    USER_INFO = "user_info"
    # 17
    TEAM_INFO = "team_info"
    # 18
    ROLE = "role"

    def __str__(self):
        return self.value


class EdgeType(BaseEnum):

    """
    DAG data type enum

    * DATA - encrypted data
    * KEY - encrypted key
    * LINK - like a key, but not encrypted
    * ACL - unencrypted set of access control flags
    * DELETION - removal of the previous edge at the same coordinates
    * DENIAL - an element that was shared through graph relationship, can be explicitly denied
    * UNDENIAL - negates the effect of denial, bringing back the share

    """
    DATA = "data"
    KEY = "key"
    LINK = "link"
    ACL = "acl"
    DELETION = "deletion"
    DENIAL = "denial"
    UNDENIAL = "undenial"

    # To store discovery, you would need data and key. To store relationships between records after the discovery
    # data was converted, you use Link.
+ + def __str__(self) -> str: + return str(self.value) + + +class Ref(BaseModel): + type: RefType + value: str + name: Optional[str] = None + + +class DAGData(BaseModel): + type: EdgeType + ref: Ref + parentRef: Optional[Ref] = None + content: Optional[str] = None + path: Optional[str] = None + + +class DataPayload(BaseModel): + origin: Ref + dataList: List + graphId: Optional[int] = 0 + + +class SyncDataItem(BaseModel): + ref: Ref + parentRef: Optional[Ref] = None + content: Optional[str] = None + content_is_base64: bool = True + type: Optional[str] = None + path: Optional[str] = None + deletion: Optional[bool] = False + + +class SyncData(BaseModel): + syncPoint: int + data: List[SyncDataItem] + hasMore: bool + + +class RecordField(BaseModel): + type: str + label: Optional[str] = None + value: List[Any] = [] + required: bool = False + + +class UserAclRotationSettings(BaseModel): + # Base64 JSON schedule + schedule: Optional[str] = "" + + # Base64 JSON, encrypted + pwd_complexity: Optional[str] = "" + + disabled: bool = False + + # If true, do not rotate the username/password on remote system, if it exists. + noop: bool = False + + # A list of SaaS Record configuration records. 
+ saas_record_uid_list: List[str] = [] + + def set_pwd_complexity(self, complexity: Union[dict, str, bytes], record_key_bytes: bytes): + if isinstance(complexity, dict): + complexity = json.dumps(complexity) + if isinstance(complexity, str): + complexity = complexity.encode() + + if not isinstance(complexity, bytes): + raise ValueError("The complexity is not a dictionary, string or is bytes.") + + self.pwd_complexity = base64.b64encode(dag_crypto.encrypt_aes(complexity, record_key_bytes)).decode() + + def get_pwd_complexity(self, record_key_bytes: bytes) -> Optional[dict]: + if self.pwd_complexity is None or self.pwd_complexity == "": + return None + complexity_enc_bytes = base64.b64decode(self.pwd_complexity.encode()) + complexity_bytes = dag_crypto.decrypt_aes(complexity_enc_bytes, record_key_bytes) + return json.loads(complexity_bytes) + + def set_schedule(self, schedule: Union[dict, str]): + if isinstance(schedule, dict): + schedule = json.dumps(schedule) + self.schedule = schedule + + def get_schedule(self) -> Optional[dict]: + if self.pwd_complexity is None or self.pwd_complexity == "": + return None + return json.loads(self.schedule) + + +class UserAcl(BaseModel): + # Is this user's password/private key managed by this resource? + # This should be unique for all the ACL edges of this user vertex; only one ACL edge should have a True value. + belongs_to: bool = False + + # Is this user an admin for the resource? + # This can be set True for multiple ACL edges; a user can be admin on multiple resources. + is_admin: bool = False + + # Is this user a cloud-based user? + # This will only be True if the ACL of the PAM User connects to a configuration vertex. + is_iam_user: Optional[bool] = False + + rotation_settings: Optional[UserAclRotationSettings] = None + + @staticmethod + def default(): + """ + Make an empty UserAcl that contains all the default values for the attributes. 
+ """ + return UserAcl( + rotation_settings=UserAclRotationSettings() + ) + +class DiscoveryItem(BaseModel): + pass + + +class DiscoveryConfiguration(DiscoveryItem): + """ + This is very general. + We are not going to make a class for each configuration/provider. + Populate a dictionary for the important information (i.e., Network CIDR) + """ + type: str + info: dict + + # Configurations never allows an admin user. + # This should always be False. + allows_admin: bool = False + + +class DiscoveryUser(DiscoveryItem): + user: Optional[str] = None + dn: Optional[str] = None + database: Optional[str] = None + managed: bool = False + + # These are for directory services. + active: bool = True + expired: bool = False + source: Optional[str] = None + + # Normally these do not get set, except for the access_user. + password: Optional[str] = None + private_key: Optional[str] = None + private_key_passphrase: Optional[str] = None + + # Simple flag, for access user in discovery, that states could connect with creds. + # Local connection might not have passwords, so this is our flag to indicate that the user connected. 
+ could_login: Optional[bool] = False + + +class FactsDirectory(BaseModel): + domain: str + software: Optional[str] = None + login_format: Optional[str] = None + + +class FactsId(BaseModel): + machine_id: Optional[str] = None + product_id: Optional[str] = None + board_serial: Optional[str] = None + + +class FactsNameUser(BaseModel): + name: str + user: str + + +class Facts(BaseModel): + name: Optional[str] = None + + # For devices + make: Optional[str] = None + model: Optional[str] = None + + directories: List[FactsDirectory] = [] + id: Optional[FactsId] = None + services: List[FactsNameUser] = [] + tasks: List[FactsNameUser] = [] + iis_pools: List[FactsNameUser] = [] + + @property + def has_services(self): + return self.services is not None and len(self.services) > 0 + + @property + def has_tasks(self): + return self.tasks is not None and len(self.tasks) > 0 + + @property + def has_iis_pools(self): + return self.iis_pools is not None and len(self.iis_pools) > 0 + + @property + def has_service_items(self): + return self.has_services or self.has_tasks or self.has_iis_pools + + +class DiscoveryMachine(DiscoveryItem): + host: str + ip: str + port: Optional[int] = None + os: Optional[str] = None + provider_region: Optional[str] = None + provider_group: Optional[str] = None + is_gateway: bool = False + allows_admin: bool = True + admin_reason: Optional[str] = None + facts: Optional[Facts] = None + + +class DiscoveryDatabase(DiscoveryItem): + host: str + ip: str + port: int + type: str + use_ssl: bool = False + database: Optional[str] = None + provider_region: Optional[str] = None + provider_group: Optional[str] = None + allows_admin: bool = True + admin_reason: Optional[str] = None + + +class DiscoveryDirectory(DiscoveryItem): + host: str + ip: str + ips: List[str] = [] + port: int + type: str + use_ssl: bool = False + provider_region: Optional[str] = None + provider_group: Optional[str] = None + allows_admin: bool = True + admin_reason: Optional[str] = None + + +class 
DiscoveryObject(BaseModel): + uid: str + id: str + object_type_value: str + record_uid: Optional[str] = None + parent_record_uid: Optional[str] = None + record_type: str + fields: List[RecordField] + ignore_object: bool = False + action_rules_result: Optional[str] = None + admin_uid: Optional[str] = None + shared_folder_uid: Optional[str] = None + name: str + title: str + description: str + notes: List[str] = [] + error: Optional[str] = None + stacktrace: Optional[str] = None + + # If the object is missing, this will show a timestamp on when it went missing. + missing_since_ts: Optional[int] = None + + # Should this object be deleted? This does not prevent user from deleting, but prevents automated processed from + # deleting. + allow_delete: bool = False + + # This is not the official admin. + # This is the user discovery used to access to the resource. + # This will be used to help the user create an admin user. + access_user: Optional[DiscoveryUser] = None + + # Specific information for a record type. + item: Union[DiscoveryConfiguration, DiscoveryUser, DiscoveryMachine, DiscoveryDatabase, DiscoveryDirectory] + + @property + def record_exists(self): + return self.record_uid is not None + + def get_field_value(self, label): + for field in self.fields: + if field.label == label or field.type == label: + value = field.value + if len(value) == 0: + return None + return field.value[0] + return None + + def set_field_value(self, label, value): + if not isinstance(value, list): + value = [value] + for field in self.fields: + if field.label == label or field.type == label: + field.value = value + return + raise ValueError(f"Cannot not find field with label {label}") + + @staticmethod + def get_discovery_object(vertex: "DAGVertex") -> "DiscoveryObject": + """ + Get DiscoveryObject with correct item instance. + + Pydantic doesn't like Unions on the item attribute. + Item needs to be validated using the correct class. 
+ + :param vertex: + :return: + """ + + mapping = { + "pamUser": DiscoveryUser, + "pamDirectory": DiscoveryDirectory, + "pamMachine": DiscoveryMachine, + "pamDatabase": DiscoveryDatabase + } + + content_dict = vertex.content_as_dict + + if content_dict is None: + raise Exception(f"The discovery vertex {vertex.uid} does not have any content data.") + record_type = content_dict.get("record_type") + if record_type in mapping: + content_dict["item"] = mapping[record_type].model_validate(content_dict["item"]) + else: + content_dict["item"] = DiscoveryConfiguration.model_validate(content_dict["item"]) + + return DiscoveryObject.model_validate(content_dict) + + +class CredentialBase(BaseModel): + # Use Any because it might be a str or Secret, but Secret is defined to discover-and_rotation. + user: Optional[Any] = None + dn: Optional[Any] = None + password: Optional[Any] = None + private_key: Optional[Any] = None + private_key_passphrase: Optional[Any] = None + database: Optional[Any] = None + + +class Settings(BaseModel): + + """ + credentials: List of Credentials used to test connections for resources. + default_shared_folder_uid: The default shared folder that should be used when adding records. + include_azure_aadds - Include Azure AD Domain Service. + skip_rules: Do not run the rule engine. + user_map: Map used to map found users to Keeper record UIDs + skip_machines: Do not discovery machines. + skip_databases: Do not discovery databases. + skip_directories: Do not discovery directoires. + skip_cloud_users - Skip cloud users like AWS IAM, or Azure Tenant users. + allow_resource_deletion - Allow discovery to remove resources. + allow_resource_deletion - Allow discovery to remove resources if missing. + allow_user_deletion - Allow discovery to remove users if missing. + resource_deletion_limit - Remove resource if not seen for # seconds; 0 will delete right away. + user_deletion_limit - Remove user right away if not seen for # seconds; 0 will delete right away. 
+ """ + + credentials: List[CredentialBase] = [] + default_shared_folder_uid: Optional[str] = None + include_azure_aadds: bool = False + skip_rules: bool = False + user_map: Optional[List[dict]] = None + skip_machines: bool = False + skip_databases: bool = False + skip_directories: bool = False + skip_cloud_users: bool = False + + # For now, don't delete anything. + allow_resource_deletion: bool = False + allow_user_deletion: bool = False + + resource_deletion_limit: int = 0 + user_deletion_limit: int = 0 + + def set_user_map(self, obj): + if self.user_map is not None: + obj.user_map = self.user_map + + @property + def has_credentials(self): + return len(self.credentials) > 0 + +class DiscoveryDeltaItem(BaseModel): + uid: str + version: int + record_uid: Optional[str] = None + changes: Optional[dict] = None + + @property + def has_record(self) -> bool: + return self.record_uid is not None + +class DiscoveryDelta(BaseModel): + added: List[DiscoveryDeltaItem] = [] + changed: List[DiscoveryDeltaItem] = [] + deleted: List[DiscoveryDeltaItem] = [] + + +class JobItem(BaseModel): + job_id: str + start_ts: int + settings: Settings + end_ts: Optional[int] = None + success: Optional[bool] = None + resource_uid: Optional[str] = None + conversation_id: Optional[str] = None + error: Optional[str] = None + stacktrace: Optional[str] = None + + sync_point: Optional[int] = None + + # Stored chunked, in multiple DATA edges + delta: Optional[DiscoveryDelta] = None + + @property + def duration_sec(self) -> Optional[int]: + if self.end_ts is not None: + return self.end_ts - self.start_ts + return None + + @property + def start_ts_str(self): + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.start_ts)) + + @property + def end_ts_str(self): + if self.end_ts is not None: + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.end_ts)) + return "" + + @property + def duration_sec_str(self): + if self.is_running is True: + duration_sec = int(time.time()) - self.start_ts 
+ else: + duration_sec = self.duration_sec + + if duration_sec is not None: + return str(datetime.timedelta(seconds=int(duration_sec))) + else: + return "" + + @property + def is_running(self): + # If no end timestamp, and there is a start timestamp, and the job has not been processed, and there is no + # success is running. + return self.end_ts is None and self.start_ts is not None and self.success is None + + +class JobContent(BaseModel): + active_job_id: Optional[str] = None + job_history: List[JobItem] = [] + + +class PamGraphId(BaseEnum): + PAM = 0 + DISCOVERY_RULES = 10 + DISCOVERY_JOBS = 11 + INFRASTRUCTURE = 12 + SERVICE_LINKS = 13 + + +class RuleTypeEnum(BaseEnum): + ACTION = "action" + SCHEDULE = "schedule" + COMPLEXITY = "complexity" + + +class RuleActionEnum(BaseEnum): + PROMPT = "prompt" + ADD = "add" + IGNORE = "ignore" + + +class Statement(BaseModel): + field: str + operator: str + value: Any + + +class RuleItem(BaseModel): + name: Optional[str] = None + added_ts: Optional[int] = None + rule_id: Optional[str] = None + enabled: bool = True + priority: int = 0 + case_sensitive: bool = True + statement: List[Statement] + + # Do not set this. + # This needs to be here for the RuleEngine. + # The RuleEngine will set this to its self. 
+ engine_rule: Optional[object] = None + + @property + def added_ts_str(self): + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.added_ts)) + + def search(self, search: str) -> bool: + for item in self.statement: + if search in item.field or search in item.value: + return True + + if search in self.rule_id.lower() or search == self.rule_action.value or search == str(self.priority): + return True + + return False + + def close(self): + try: + if self.engine_rule and hasattr(self.rule_engine, "close"): + self.engine_rule.close() + self.engine_rule = None + del self.engine_rule + except (Exception,): + pass + + def __del__(self): + self.close() + + +class ActionRuleItem(RuleItem): + action: Optional[RuleActionEnum] = RuleActionEnum.PROMPT + shared_folder_uid: Optional[str] = None + admin_uid: Optional[str] = None + + +class ScheduleRuleItem(RuleItem): + tag: str + + +class ComplexityRuleItem(RuleItem): + tag: str + + +class RuleSet(BaseModel): + rules: List[RuleItem] = [] + + @property + def count(self) -> int: + return len(self.rules) + + def __str__(self): + rule_set = [] + for item in self.rules: + rule_set.append(item.model_dump_json()) + + return "[" + ",\n" .join(rule_set) + "]" + + +class ActionRuleSet(RuleSet): + rules: List[ActionRuleItem] = [] + + +class ScheduleRuleSet(RuleSet): + rules: List[ScheduleRuleItem] = [] + + +class ComplexityRuleSet(RuleSet): + rules: List[ComplexityRuleItem] = [] + +class ServiceAcl(BaseModel): + is_service: bool = False + is_task: bool = False + is_iis_pool: bool = False + + def is_used(self): + return self.is_service or self.is_task or self.is_iis_pool + +class PromptActionEnum(BaseEnum): + ADD = "add" + IGNORE = "ignore" + SKIP = "skip" + +class DirectoryInfo(BaseModel): + directory_record_uids: List[str] = [] + directory_user_record_uids: List[str] = [] + + def has_directories(self) -> bool: + return len(self.directory_record_uids) > 0 + + +class NormalizedRecord(BaseModel): + """ + This class attempts to 
normalize KeeperRecord, TypedRecord, KSM Record into a normalized record. + """ + record_uid: str + record_type: str + title: str + fields: List[RecordField] = [] + note: Optional[str] = None + + def _field(self, field_type, label) -> Optional[RecordField]: + for field in self.fields: + value = field.value + if value is None or len(value) == 0: + continue + if field.label == field_type and value[0].lower() == label.lower(): + return field + return None + + def find_user(self, user): + + from .dag_utils import split_user_and_domain + + res = self._field("login", user) + if res is None: + user, _ = split_user_and_domain(user) + res = self._field("login", user) + + return res + + def find_dn(self, user): + return self._field("distinguishedName", user) + + +class PromptResult(BaseModel): + + # "add" and "ignore" are the only action + action: PromptActionEnum + + # The acl is only needs for pamUser record. + acl: Optional[UserAcl] = None + + # If the discovery object content has been modified, set it here. + content: Optional[DiscoveryObject] = None + + # Existing record that should be the admin. + record_uid: Optional[str] = None + + # Is this is a pamUser and a directory user? + is_directory_user: bool = False + + # Note to include with record + note: Optional[str] = None + + +class BulkRecordAdd(BaseModel): + + # The title of the record. + # This is used for debug reasons. + title: str + + # Record note + note: Optional[str] = None + + # This could be a Commander KeeperRecord, Commander RecordAdd, NormalizedRecord, or KSM Record + record: Any + record_type: str + + # If record_type is a PAM User, is this user the admin of the resource? + admin_uid: Optional[str] = None + + # Normal record UID strings + record_uid: str + parent_record_uid: Optional[str] = None + + # The shared folder UID where the record should be created. 
+ shared_folder_uid: str + + +class BulkRecordConvert(BaseModel): + record_uid: str + parent_record_uid: Optional[str] = None + + # Record note + note: Optional[str] = None + + +class BulkRecordSuccess(BaseModel): + title: str + record_uid: str + + +class BulkRecordFail(BaseModel): + title: str + error: str + + +class BulkProcessResults(BaseModel): + success: List[BulkRecordSuccess] = [] + failure: List[BulkRecordFail] = [] + + @property + def has_failures(self) -> bool: + return len(self.failure) > 0 + + @property + def num_results(self) -> int: + return self.failure_count + self.success_count + + @property + def failure_count(self) -> int: + return len(self.failure) + + @property + def success_count(self) -> int: + return len(self.success) diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py new file mode 100644 index 00000000..cca0e385 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py @@ -0,0 +1,116 @@ +import os +from typing import List, Optional, Tuple +from .__version__ import __version__ + + +def value_to_boolean(value): + value = str(value) + if value.lower() in ['true', 'yes', 'on', '1']: + return True + elif value.lower() in ['false', 'no', 'off', '0']: + return False + else: + return None + + +def kotlin_bytes(data: bytes): + return [b if b < 128 else b - 256 for b in data] + + +def get_connection(**kwargs): + + """ + This method will return the proper connection based on the params passed in. + + If `ksm` and a KDNRM KSM instance, it will connect using keeper secret manager. + If `params` and a KeeperParam instance, it will connect using Commander. + If the env var `USE_LOCAL_DAG` is True, it will connect using the Local test DAG engine. + + It returns a child instance of the Connection class. + """ + + # if the connection is passed in, return it. 
+ if kwargs.get("connection") is not None: + return kwargs.get("connection") + + vault = kwargs.get("vault") + logger = kwargs.get("logger") + if value_to_boolean(os.environ.get("USE_LOCAL_DAG")): + from ..keeper_dag.connection.local import Connection + conn = Connection(logger=logger) + else: + use_read_protobuf = kwargs.get("use_read_protobuf") + use_write_protobuf = kwargs.get("use_write_protobuf") + + if vault is not None: + from ..keeper_dag.connection.commander import Connection + conn = Connection(vault=vault, + logger=logger, + use_read_protobuf=use_read_protobuf, + use_write_protobuf=use_write_protobuf) + else: + raise ValueError("Must pass 'vault' for Keeper SDK. Found neither.") + return conn + + +def make_agent(text) -> str: + return f"{text}/{__version__}" + + + +def split_user_and_domain(user: str) -> Tuple[Optional[str], Optional[str]]: + + if user is None: + return None, None + + domain = None + + if "\\" in user: + user_parts = user.split("\\", maxsplit=1) + user = user_parts[0] + domain = user_parts[1] + elif "@" in user: + user_parts = user.split("@") + domain = user_parts.pop() + user = "@".join(user_parts) + + return user, domain + + +def user_check_list(user: str, name: Optional[str] = None, source: Optional[str] = None) -> List[str]: + user, domain = split_user_and_domain(user) + user = user.lower() + + # TODO: Add boolean for tasks to include `local users` patterns. + # It appears that for task lists, directory users do not have domains. + # A problem could arise where the customer uses a local user and directory with the same name. 
+ check_list = [user, f".\\{user}"] + if name is not None: + name = name.lower() + check_list += [name, f".\\{name}"] + if source is not None: + source = source.lower() + check_list.append(f"{source[:15]}\\{user}") + check_list.append(f"{user}@{source}") + netbios_parts = source.split(".") + if len(netbios_parts) > 1: + check_list.append(f"{netbios_parts[0][:15]}\\{user}") + check_list.append(f"{user}@{netbios_parts[0]}") + if domain is not None: + domain = domain.lower() + check_list.append(f"{domain[:15]}\\{user}") + check_list.append(f"{user}@{domain}") + domain_parts = domain.split(".") + if len(domain_parts) > 1: + check_list.append(f"{domain_parts[0][:15]}\\{user}") + check_list.append(f"{user}@{domain_parts[0]}") + + return list(set(check_list)) + + +def user_in_lookup(user: str, lookup: dict, name: Optional[str] = None, source: Optional[str] = None) -> bool: + + for check_user in user_check_list(user, name, source): + if check_user in lookup: + return True + return False diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_vertex.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_vertex.py new file mode 100644 index 00000000..467760c9 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_vertex.py @@ -0,0 +1,913 @@ +from typing import Optional, Union, List, Any, Tuple, TYPE_CHECKING + +import pydantic +from .dag_types import EdgeType, RefType +from . import dag_crypto +from .dag_edge import DAGEdge +from .exceptions import DAGVertexException, DAGDeletionException, DAGIllegalEdgeException, DAGKeyException + +if TYPE_CHECKING: + from .dag import DAG + + +class DAGVertex: + + def __init__(self, dag: "DAG", uid: Optional[str] = None, name: Optional[str] = None, + keychain: Optional[bytes] = None, vertex_type: RefType = RefType.GENERAL): + + self.dag = dag + + # If the UID is not set, generate a UID. + if uid is None: + uid = dag_crypto.generate_uid_str() + # Else verify that the UID is valid. 
The UID should be a 16-byte value that is web-safe base64 serialized. + else: + if len(uid) != 22: + raise ValueError(f"The uid {uid} is not a 22 characters in length.") + try: + b = dag_crypto.urlsafe_str_to_bytes(uid) + if len(b) != 16: + raise ValueError("not 16 bytes") + except Exception: + raise ValueError("The uid does not appear to be web-safe base64 string contains a 16 bytes value.") + + # If the UID is the root UID, make sure the vertex type is not general. + # The root vertex needs to be either PAM_NETWORK or PAM_USER, if not set to PAM_NETWORK. + if uid == self.dag.uid and (vertex_type != RefType.PAM_NETWORK and vertex_type != RefType.PAM_USER): + vertex_type = RefType.PAM_NETWORK + self.vertex_type = vertex_type + + # If the name is not defined, use the UID. Name is not persistent in the DAG. + # If you load the DAG, the web service will not return the name. + if name is None: + name = uid + + self._uid = uid + self._name = name + + # The keychain is a list of keys that can be used. + # The keychain may contain multiple keys, when loading the default graph (graph_id) + # For normal editing, the keychain will contain only one key. + self._keychain = [] + if keychain is not None: + if not isinstance(keychain, list): + keychain = [keychain] + self._keychain += keychain + + # Is the keychain corrupt? + self.corrupt = False + + # These are edges to which vertex own this vertex. This vertex belongs to. So this would + self.edges: list[Optional[DAGEdge]] = [] + self.has_uid = [] + + # Flag indicating that this vertex is active. + # This means this vertex has an active edge connected to another vertex. + self.active = True + + # By default, we will save this vertex; not skip_save. + # If in the process building the graph, it is decided that a vertex should not be saved; this can be set to + # prevent the vertex from being saved. 
+ self._skip_save = False + + def __str__(self): + ret = f"Vertex {self.uid}\n" + ret += f" python instance id: {id(self)}\n" + ret += f" name: {self.name}\n" + ret += f" keychain: {self.keychain}\n" + ret += f" active: {self.active}\n" + ret += f" edges:\n" + for edge in self.edges: + ret += f" * type {self.dag.__class__.EDGE_LABEL.get(edge.edge_type)}" + ret += f", connect to {edge.head_uid}" + ret += f", path {edge.path}, " + ret += f", active: {edge.active}" + ret += f", modified: {edge.modified}" + ret += f", content: {'yes' if edge.content is not None else 'no'}" + ret += f", content type: {type(edge.content)}" + ret += "\n" + return ret + + def __repr__(self): + return f"" + + def debug(self, msg: str, level: int = 0): + self.dag.debug(msg, level=level) + + @property + def name(self) -> str: + """ + Get the name for vertex + + If the name is not defined, the UID will be returned. + The name is not persistent. + If loading a DAG, the name will not be set. + + :return: + """ + if self._name is not None: + return self._name + return self._uid + + @property + def key(self) -> Optional[Union[str, bytes]]: + """ + Get a single key from the keychain. + + :return: + """ + + keychain = self.keychain + if len(keychain) > 0: + return self.keychain[0] + + return None + + @property + def skip_save(self): + return self._skip_save + + @skip_save.setter + def skip_save(self, value): + self._skip_save = value + + for vertex in self.has_vertices(): + vertex._skip_save = value + + def add_to_keychain(self, key: Union[str, bytes]): + """ + Add a key to the keychain + + :param key: A decrypted key bytes or encrypted key str + :return: + """ + if key not in self._keychain: + self._keychain.append(key) + + @property + def keychain(self) -> Optional[List[Union[str, bytes]]]: + """ + Get the keychain for the vertex. + + The key is stored on the edges, however, the key belongs to the vertex. + KEY and ACL edges from this vertex will have the same encrypted key. 
+ It is simpler to store the key on the DAGVertex instance. + + The keychain in an array of keys. + When using graph_id = 0, different graphs that have the same UID will + have different keys. + When decrypting DATA edges, each key in the keychain will be tried. + + If the keychain has not been set, check if any edges exist that require a key. + If there are, then generate a random key. + The load process will populate the key. + If the vertex does not have a key in the keychain, it is because this is a newly + added vertex. + + If there are no edges that require a key, then return None. + """ + + # If the vertex is root, then the keychain will be the key bytes. + if self.dag.get_root == self: + self._keychain = [self.dag.key] + + # If the keychain is empty, generate a key for a specific edge type. + elif len(self._keychain) == 0: + for e in self.edges: + if e.edge_type in [EdgeType.KEY, EdgeType.DATA]: + self._keychain.append(dag_crypto.generate_random_bytes(self.dag.__class__.UID_KEY_BYTES_SIZE)) + break + + return self._keychain + + @keychain.setter + def keychain(self, value: List[Union[str, bytes]]): + """ + Set the key in the vertex. + + The save method will use this key for any KEY/ACL edges. + A key of str type means it is encrypted. + """ + self._keychain = value + + @property + def has_decrypted_keys(self) -> Optional[bool]: + """ + Does the vertex have a decrypted keys? + + If the vertex contains a KEY, ACL or DATA edge and if the key is bytes, then the key is decrypted. + If it is a str type, then it is encrypted. + """ + if len(self._keychain) > 0: + for e in self.edges: + if e.edge_type in [EdgeType.KEY, EdgeType.DATA]: + all_decrypted = True + for key in self._keychain: + if not isinstance(key, bytes): + all_decrypted = False + break + return all_decrypted + return None + + @property + def uid(self): + """ + Get the vertex UID. + + Once set, don't allow it to be changed. 
+ """ + return self._uid + + def get_edge(self, vertex: "DAGVertex", edge_type: EdgeType) -> DAGEdge: + high_edge = None + high_version = -1 + for edge in self.edges: + # Get all the edge point at the same vertex. + # Don't include DATA edges. + if edge.head_uid == vertex.uid and edge.edge_type == edge_type: + if edge.version > high_version: + high_version = edge.version + high_edge = edge + return high_edge + + def get_highest_edge_version(self, head_uid: str) -> Tuple[int, Optional[DAGEdge]]: + """ + Find the highest DAGEdge version of all edge types. + + :param head_uid: + :return: + """ + + high_edge = None + high_version = -1 + for edge in self.edges: + # Get all the edge point at the same vertex. + # Don't include DATA edges. + if edge.head_uid == head_uid: + if edge.version > high_version: + high_edge = edge + high_version = edge.version + return high_version, high_edge + + def edge_count(self, vertex: "DAGVertex", edge_type: EdgeType) -> int: + """ + Get the number of edges between two vertices. + + :param vertex: + :param edge_type: + :return: + """ + count = 0 + for edge in self.edges: + if edge.head_uid == vertex.uid and edge.edge_type == edge_type: + count += 1 + return count + + def edge_by_type(self, vertex: "DAGVertex", edge_type: EdgeType) -> List[DAGEdge]: + edge_list = [] + for edge in self.edges: + if edge.edge_type == edge_type and edge.head_uid == vertex.uid: + edge_list.append(edge) + return edge_list + + @property + def has_data(self) -> bool: + + """ + Does this vertex contain a DATA edge? + + :return: True if vertex has a DATA edge. + """ + + for item in self.edges: + if item.edge_type == EdgeType.DATA: + return True + return False + + def get_data(self, index: Optional[int] = None) -> Optional[DAGEdge]: + """ + Get data edge + + If the index is None or 0, the latest data edge will be returned. + A positive and negative, non-zero, index will return the same data. + It will be the absolute value of the index from the latest data. 
+ This means the 1 or -1 will return the prior data. + + If there is no data, None is returned. + + :param index: + :return: + """ + + data_list = self.edge_by_type(self, EdgeType.DATA) + data_count = len(data_list) + if data_count == 0: + return None + + # If the index is None, get the latest. + if index is None or index == 0: + index = -1 + # Since -1 is the current, switch index to a negative number and subtract one more. + # For example, 1 means prior, -1 would be the latest, so we need to subtract one to get -2. + elif index > 0: + index *= -1 + index -= 1 + # If already a negative index, just subtract one. + else: + index -= 1 + + try: + data = data_list[index] + except IndexError: + raise ValueError(f"The index is not valid. Currently there are {data_count} data edges") + + return data + + def add_data(self, + content: Any, + is_encrypted: bool = False, + is_serialized: bool = False, + path: Optional[str] = None, + modified: bool = True, + from_load: bool = False, + needs_encryption: bool = True): + + """ + Add a DATA edge to the vertex. + + :param content: The content to store in the DATA edge. + :param is_encrypted: Is the content encrypted? + :param is_serialized: Is the content base64 serialized? + :param path: Simple string tag to identify the edge. + :param modified: Does this modify the content? + By default, adding a DATA edge will flag that the edge has been modified. + If loading, modified will be set to False. + :param from_load: This call is being performed the load() method. + Do not validate adding data. + :param needs_encryption: Default is True. + Does the content need to be encrypted? + """ + + self.debug(f"connect {self.uid} to DATA edge", level=1) + + # Are we trying to add DATA to a deleted vertex? + + if not self.active: + # If deleted, there will not be a KEY to decrypt the data. + # Throw an exception if not from the loading method. + if not from_load: + raise DAGDeletionException("This vertex is not active. 
Cannot add DATA edge.") + # If from loading, do not add and do not throw an exception. + return + + # Make sure the vertex belongs before auto saving. If it does not belong, it's just an orphan right now. + # This only is checked if using this module is used to create the graph. + if self.belongs_to_a_vertex is False and from_load is False: + raise DAGVertexException(f"Before adding data, connect this vertex {self.uid} to another vertex.") + + # Make sure that we have a KEY. + # Allow a DATA edge to be connected to the root vertex, which will not have a KEY edge. + # Or if we are loading, allow out of sync edges. + + if needs_encryption: + found_key_edge = self.dag.get_root == self or from_load is True + if found_key_edge is False: + for edge in self.edges: + if edge.edge_type == EdgeType.KEY: + found_key_edge = True + if found_key_edge is False: + raise DAGKeyException(f"Cannot add DATA edge without a KEY edge for vertex {self.uid}.") + + # Get the prior data, set the version and inactive the prior data. + version = 0 + prior_data = self.get_data() + if prior_data is not None: + version = prior_data.version + 1 + prior_data.active = False + + # Check if DATA has already been created/modified per this session. + # If it has, the prior will be overwritten, no sense on saving this edge. + # If warning is enabled, print a debug message and the stacktrace to we what added the DATA. + if self.dag.dedup_edge and prior_data.modified: + prior_data.skip_on_save = True + if self.dag.dedup_edge_warning: + self.dag.debug("DATA edge added multiple times for session. stacktrace on what did it follows ...") + self.dag.debug_stacktrace() + + # The tail UID is the UID of the vertex. Since data loops back to the vertex, the head UID is the same. 
+ self.edges.append( + DAGEdge( + vertex=self, + edge_type=EdgeType.DATA, + head_uid=self.uid, + version=version, + content=content, + path=path, + modified=modified, + is_serialized=is_serialized, + is_encrypted=is_encrypted, + needs_encryption=needs_encryption + ) + ) + + # If using a history level, we want to remove edges if we exceed the history level. + # The history level is per edge type. + # It's FIFO, so we will remove the first edge type if we exceed the history level. + if self.dag.history_level > 0: + data_count = self.data_count() + while data_count > self.dag.history_level: + for index in range(0, len(self.edges) - 1): + if self.edges[index].edge_type == EdgeType.DATA: + del self.edges[index] + data_count -= 1 + break + + self.dag.do_auto_save() + + def data_count(self): + return self.edge_count(self, EdgeType.DATA) + + def data_delete(self): + + # Get the DATA edge. + # It will be a reference to itself. + data_edge = self.get_edge(self, EdgeType.DATA) + if data_edge is None: + self.debug("cannot delete the data, no data edge exists.") + + data_edge.active = False + + self.belongs_to( + vertex=self, + edge_type=EdgeType.DELETION + ) + self.debug(f"deleted data edge for {self.uid}") + + @property + def latest_data_version(self): + version = -1 + for edge in self.edges: + if edge.edge_type == EdgeType.DATA and edge.version > version: + version = edge.version + return version + + @property + def content(self) -> Optional[Union[str, bytes]]: + """ + Get the content of the active DATA edge. + + If the content is a str, then the content is encrypted. + """ + data_edge = self.get_data() + if data_edge is None: + return None + return data_edge.content + + @property + def content_as_dict(self) -> Optional[dict]: + """ + Get the content from the active DATA edge as a dictionary. + :return: Content as a dictionary. 
+ """ + data_edge = self.get_data() + if data_edge is None: + return None + return data_edge.content_as_dict + + @property + def content_as_str(self) -> Optional[str]: + """ + Get the content from the active DATA edge as a str. + :return: Content as a str. + """ + + data_edge = self.get_data() + if data_edge is None: + return None + return data_edge.content_as_str + + def content_as_object(self, + meta_class: pydantic._internal._model_construction.ModelMetaclass) -> Optional[pydantic.BaseModel]: + """ + Get the content as a pydantic based object. + + :param meta_class: The class to return + :return: + """ + data_edge = self.get_data() + if data_edge is None: + return None + + return data_edge.content_as_object(meta_class) + + @property + def has_key(self) -> bool: + + """ + Does this vertex contain any KEY or ACL edges? + + :return: True if vertex has a KEY or ACL edge. + """ + + for item in self.edges: + if item.edge_type == EdgeType.KEY: + return True + return False + + def belongs_to(self, + vertex: "DAGVertex", + edge_type: EdgeType, + content: Optional[Any] = None, + is_encrypted: bool = False, + path: Optional[str] = None, + modified: bool = True, + from_load: bool = False): + + """ + Connect a vertex to another vertex (as the owner). + + This will create an edge between this vertex and the passed in vertex. + The passed in vertex will own this vertex. + + If the edge_type is a KEY or ACL, data will be treated as a key. If a DATA edge already exists, the + edge_type will be changed to a KEY, if not a KEY or ACL edge_type. + + :param vertex: The vertex has this vertex. + :param edge_type: The edge type that connects the two vertices. + :param content: Data to store as the edges content. + :param is_encrypted: Is the content encrypted? + :param path: Text tag for the edge. + :param modified: Does adding this edge modify the stored DAG? + :param from_load: Is being connected from load() method? 
+ :return: + """ + + self.debug(f"connect {self.uid} to {vertex.uid} with edge type {edge_type.value}", level=1) + + if vertex is None: + raise ValueError("Vertex is blank.") + if self.uid == self.dag.uid and not (edge_type == EdgeType.DATA or edge_type == EdgeType.DELETION): + if not from_load: + raise DAGIllegalEdgeException(f"Cannot create edge to self for edge type {edge_type}.") + self.dag.debug(f"vertex {self.uid} , the root vertex, " + f"attempted to create '{edge_type.value}' edge to self, skipping.") + return + + # Cannot make an edge to the same vertex, unless the edge type is a DELETION. + # Normally an edge to self is a DATA type, use add_data for that. + # A DELETION edge to self is allowed. + # Just means the DATA edge is being deleted. + if self.uid == vertex.uid and not (edge_type == EdgeType.DATA or edge_type == EdgeType.DELETION): + if not from_load: + raise DAGIllegalEdgeException(f"Cannot create edge to self for edge type {edge_type}.") + self.dag.debug(f"vertex {self.uid} attempted to make '{edge_type.value}' to self, skipping.") + return + + # Figure out what version of the edge we are. + + version, version_edge = self.get_highest_edge_version(head_uid=vertex.uid) + + # If the new edge is not DELETION + if edge_type != EdgeType.DELETION: + + # Find the current active edge for this edge type to make it inactive. + current_edge_by_type = self.get_edge(vertex, edge_type) + if current_edge_by_type is not None: + current_edge_by_type.active = False + + # Check if edge has already been created/modified per this session. + # If it has, the prior will be overwritten, no sense on saving this edge. + # If warning is enabled, print a debug message and the stacktrace to we what added the DATA. + if self.dag.dedup_edge and current_edge_by_type.modified: + current_edge_by_type.skip_on_save = True + if self.dag.dedup_edge_warning: + self.dag.debug(f"{edge_type.value.upper()} edge added multiple times for session. 
" + "stacktrace on what did it follows ...") + self.dag.debug_stacktrace() + + # If we are adding a non-DELETION edge, it will inactivate the DELETION edge. + highest_deletion_edge = self.get_edge(vertex, EdgeType.DELETION) + if highest_deletion_edge is not None: + highest_deletion_edge.active = False + + # For this purpose, only DATA edge are allow to set the is_encrypted flag. + if edge_type != EdgeType.DATA: + is_encrypted = False + + # Should we activate the vertex again? + if not self.active: + + # If the vertex is already inactive, and we are trying to delete, return. + if edge_type == EdgeType.DELETION: + return + + if self.dag.dedup_edge and version_edge.modified: + version_edge.skip_on_save = True + if self.dag.dedup_edge_warning: + self.dag.debug("edge was deleted in session, will not save DELETION edge") + self.dag.debug_stacktrace() + else: + self.dag.debug(f"vertex {self.uid} was inactive; reactivating vertex.") + self.active = True + + # Create and append a new DAGEdge instance. + # Disable the auto saving after the content is changed since the edge has not been appended yet. + # Once the edge is created, disable blocking auto save for content changes. + edge = DAGEdge( + vertex=self, + edge_type=edge_type, + head_uid=vertex.uid, + version=version + 1, + block_content_auto_save=True, + content=content, + is_encrypted=is_encrypted, + path=path, + modified=modified + ) + edge.block_content_auto_save = False + + self.edges.append(edge) + if self.uid not in vertex.has_uid: + vertex.has_uid.append(self.uid) + + self.dag.do_auto_save() + + def belongs_to_root(self, + edge_type: EdgeType, + path: Optional[str] = None): + + """ + Connect the vertex to the root vertex. + + :param edge_type: The type of edge to use for the connection. + :param path: Short tag for this edge. 
+ :return: + """ + + self.debug(f"connect {self.uid} to root", level=1) + + if self.uid == self.dag.uid: + raise DAGIllegalEdgeException("Cannot create edge to self.") + + if not self.active: + raise DAGDeletionException("This vertex is not active. Cannot connect to root.") + + # We are adding the root, we can enable auto save now. + # We can get the correct stream id with an edge to the root vertex. + self.belongs_to(self.dag.get_root, edge_type=edge_type, path=path) + + self.dag.allow_auto_save = True + self.dag.do_auto_save() + + def has_vertices(self, edge_type: Optional[EdgeType] = None, allow_inactive: bool = False, + allow_self_ref: bool = False) -> List["DAGVertex"]: + + """ + Get a list of vertices that belong to this vertex. + :return: List of DAGVertex + """ + + vertices = [] + for uid in self.has_uid: + + # This will remove DATA and DATA that have changed to DELETION edges. + # Prevent looping. + if uid == self.uid and allow_self_ref is False: + continue + + vertex = self.dag.get_vertex(uid) + if edge_type is not None: + edge = vertex.get_edge(self, edge_type=edge_type) + if edge is not None: + vertices.append(vertex) + + # If no edge type was specified, do not return DATA and DELETION. + # Also do not include vertices that are inactive by default. + elif edge_type != EdgeType.DATA and edge_type != EdgeType.DELETION: + if vertex.active is True or allow_inactive is True: + vertices.append(vertex) + + return vertices + + def has(self, vertex: "DAGVertex", edge_type: Optional[EdgeType] = None) -> bool: + + """ + Does this vertex have the passed in vertex? + + :return: True if request vertex belongs to this vertex. + False if it does not. 
+ """ + + vertices = self.has_vertices(edge_type=edge_type) + return vertex in vertices + + def belongs_to_vertices(self) -> List["DAGVertex"]: + """ + Get a list of vertices that this vertex belongs to + :return: + """ + + vertices = [] + for edge in self.edges: + # If the edge is not a DATA or DELETION type, and the edge is the highest version/active + if edge.edge_type != EdgeType.DATA and edge.edge_type != EdgeType.DELETION and edge.active is True: + + # The head will point at the remote vertex. + # If it is active, and not already in the list, add it to the list of vertices this vertex belongs to. + vertex = self.dag.get_vertex(edge.head_uid) + if vertex.active is True and vertex not in vertices: + vertices.append(vertex) + return vertices + + @property + def belongs_to_a_vertex(self) -> bool: + """ + Does this vertex belong to another vertex? + :return: + """ + + # If this is the root vertex, return True. + # Where this is being called should handle operations involving the root vertex. + if self.dag.get_root == self: + return True + + return len(self.belongs_to_vertices()) > 0 + + def disconnect_from(self, vertex: "DAGVertex", path: Optional[str] = None): + + """ + Disconnect this vertex from another vertex. + + This will add a DELETION edge between two vertices. + If the vertex no longer belongs to another vertex, the vertex will be deleted. + + :param vertex: The vertex this vertex belongs to + :param path: an Optional path for the DELETION edge. + :return: + """ + + if vertex is None: + raise ValueError("Vertex is blank.") + + # Flag all the edges as inactive. + for edge in self.edges: + if edge.head_uid == vertex.uid and edge.edge_type: + edge.active = False + + # Add the DELETION edge + self.belongs_to( + vertex=vertex, + edge_type=EdgeType.DELETION, + path=path + ) + + # If all the KEY edges are inactive now, the DATA edge needs to be made inactive. + # There is no longer a KEY edge to decrypt the DATA. 
+ has_active_key_edge = False + for edge in self.edges: + if edge.edge_type == EdgeType.KEY and edge.active is True: + has_active_key_edge = True + break + if not has_active_key_edge: + for edge in self.edges: + if edge.edge_type == EdgeType.DATA: + edge.active = False + + if not self.belongs_to_a_vertex: + self.debug(f"vertex {self.uid} is now not active", level=1) + self.active = False + + def delete(self, ignore_vertex: Optional["DAGVertex"] = None): + + """ + Delete a vertex + + Deleting a vertex will inactivate the vertex. + It will also inactivate any vertices, and their edges, that belong to the vertex. + It will not inactivate a vertex that belongs to multiple vertices. + :return: + """ + + def _delete(vertex, prior_vertex): + + # Do not delete the root vertex + if vertex.uid == self.dag.uid: + self.debug(f" * vertex is root, cannot delete root", level=2) + return + + self.debug(f"> checking vertex {vertex.uid}") + + # Should we ignore a vertex? + # If deleting an edge, we want to ignore the vertex that owns the edge. + # This prevents circular calls. + if ignore_vertex is not None and vertex.uid == ignore_vertex.uid: + return + + # Get a list of vertices that belong to this vertex (v) + has_v = vertex.has_vertices() + + if len(has_v) > 0: + self.debug(f" * vertex has {len(has_v)} vertices that belong to it.", level=2) + for v in has_v: + self.debug(f" checking {v.uid}") + _delete(v, vertex) + else: + self.debug(f" * vertex {vertex.uid} has NO vertices.", level=2) + + for e in list(vertex.edges): + if e.edge_type != EdgeType.DATA and (prior_vertex is None or e.head_uid == prior_vertex.uid): + e.delete() + if vertex.belongs_to_a_vertex is False: + self.debug(f" * inactive vertex {vertex.uid}") + vertex.active = False + + self.debug(f"DELETING vertex {self.uid}", level=3) + + # Perform the DELETION edges save in one batch. + # Get the current allowed auto save state, and disable auto save. 
+ current_allow_auto_save = self.dag.allow_auto_save + self.dag.allow_auto_save = False + + _delete(self, None) + + # Restore the allow auto save and trigger auto save() + self.dag.allow_auto_save = current_allow_auto_save + self.dag.do_auto_save() + + def walk_down_path(self, path: Union[str, List[str]]) -> Optional["DAGVertex"]: + + """ + Walk the vertices using the path and return the vertex starting at this vertex. + + :param path: An array of path string, or string where the path is joined with a "/" (i.e., think URL) + :return: DAGVertex is the path completes, None is failure. + """ + + self.debug(f"walking path in vertex {self.uid}", level=2) + + # If the path is str, break it into an array. Get rid of leading / + if isinstance(path, str): + self.debug("path is str, break into array", level=2) + if path.startswith("/"): + path = path[1:] + path = path.split("/") + + # Unshift the path + + current_path = path[0] + path = path[1:] + self.debug(f"current path: {current_path}", level=2) + self.debug(f"path left: {path}", level=2) + + # Check the DATA edges. + # If a DATA edge has the current path, return this vertex. + for edge in self.edges: + if edge.edge_type != EdgeType.DATA: + continue + if edge.path == current_path: + return self + + # Check the vertices that belong to this vertex for edges going to this vertex and the path matches. + for vertex in self.has_vertices(): + self.debug(f"vertex {self.uid} has {vertex.uid}", level=2) + for edge in vertex.edges: + # If the edge matches the current path, the head of the edge is this vertex, a route exists. + if edge.path == current_path and edge.head_uid == self.uid: + # If there is no path left, this is our vertex + if len(path) == 0: + return vertex + # If there is still more path, call vertex to walk more of the path. + else: + return vertex.walk_down_path(path) + return None + + def get_paths(self) -> List[str]: + """ + Get paths from this vertex to vertex owned by this vertex. 
+ :return: List of string paths + """ + + paths = [] + for vertex in self.has_vertices(): + for edge in vertex.edges: + if edge.path is None or edge.path == "": + continue + paths.append(edge.path) + + return paths + + def clean_edges(self): + """ + Recursively clean edges and break circular references. + + This method clears all edge lists and reference tracking to help + Python's garbage collector clean up circular references between + DAG, DAGVertex, and DAGEdge objects. + """ + # Recursively clean child vertices first + for vertex in self.has_vertices(): + vertex.clean_edges() + + # Clear all reference lists + self.edges.clear() + self.has_uid.clear() diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/exceptions.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/exceptions.py new file mode 100644 index 00000000..d1ce951c --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/exceptions.py @@ -0,0 +1,80 @@ +from __future__ import annotations +from typing import Any, Optional + + +class DAGException(Exception): + + def __init__(self, msg: Any, uid: Optional[str] = None): + if not isinstance(msg, str): + msg = str(msg) + + self.msg = msg + self.uid = uid + + super().__init__(self.msg) + + def __str__(self): + return self.msg + + def __repr__(self): + return self.msg + + +class DAGKeyIsEncryptedException(DAGException): + pass + + +class DAGDataEdgeNotFoundException(DAGException): + pass + + +class DAGDeletionException(DAGException): + pass + + +class DAGConfirmException(DAGException): + pass + + +class DAGPathException(DAGException): + pass + + +class DAGVertexAlreadyExistsException(DAGException): + pass + + +class DAGContentException(DAGException): + pass + + +class DAGDefaultGraphException(DAGException): + pass + + +class DAGIllegalEdgeException(DAGException): + pass + + +class DAGKeyException(DAGException): + pass + + +class DAGDataException(DAGException): + pass + + +class DAGVertexException(DAGException): + pass + + +class 
DAGEdgeException(DAGException): + pass + + +class DAGCorruptException(DAGException): + pass + + +class DAGConnectionException(DAGException): + pass diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/infrastructure.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/infrastructure.py new file mode 100644 index 00000000..8e7e0e42 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/infrastructure.py @@ -0,0 +1,339 @@ + +import importlib +import logging +import os +import time +from typing import Any, Optional + +from .dag import DAG +from .dag_types import EdgeType, PamGraphId +from .dag_vertex import DAGVertex +from .dag_utils import get_connection, make_agent +from .exceptions import DAGVertexException +from ... import utils + + +logger = logging.getLogger() + + +class Infrastructure: + + """ + Create a graph of the infrastructure. + + The first run will create a full graph since the vertices do not exist. + Further discovery run will only show vertices that ... + * do not have vaults records. + * the data has changed. + * the ACL has changed. 
+ + """ + + KEY_PATH = "infrastructure" + DELTA_PATH = "delta" + ADMIN_PATH = "ADMINS" + USER_PATH = "USERS" + + def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int = 0, + debug_level: int = 0, fail_on_corrupt: bool = True, log_prefix: str = "GS Infrastructure", + save_batch_count: int = 200, agent: Optional[str] = None, + **kwargs): + + # This will either be a KSM Record, or Commander KeeperRecord + self.record = record + self._dag = None + if logger is None: + logger = logging.getLogger() + self.logger = logger + self.log_prefix = log_prefix + self.history_level = history_level + self.debug_level = debug_level + self.fail_on_corrupt = fail_on_corrupt + self.save_batch_count = save_batch_count + + self.auto_save = False + self.delta_graph = True + self.last_sync_point = -1 + + self.agent = make_agent("infra") + if agent is not None: + self.agent += "; " + agent + + self.conn = get_connection(logger=logger, **kwargs) + + @property + def dag(self) -> DAG: + if self._dag is None: + + self.logger.debug(f"loading the dag graph {PamGraphId.INFRASTRUCTURE.value}") + self.logger.debug(f"setting graph save batch count to {self.save_batch_count}") + + self._dag = DAG(conn=self.conn, + record=self.record, + graph_id=PamGraphId.INFRASTRUCTURE, + auto_save=self.auto_save, + logger=self.logger, + history_level=self.history_level, + debug_level=self.debug_level, + name="Discovery Infrastructure", + fail_on_corrupt=self.fail_on_corrupt, + log_prefix=self.log_prefix, + save_batch_count=self.save_batch_count, + agent=self.agent) + + return self._dag + + @property + def has_discovery_data(self) -> bool: + # Does the graph array have any vertices? + if not self.dag.has_graph: + return False + + # If we at least have the root, does is have the configuration? 
+ if not self.get_root.has_vertices(): + return False + + return True + + @property + def get_root(self) -> DAGVertex: + return self.dag.get_root + + @property + def get_configuration(self) -> DAGVertex: + try: + configuration = self.get_root.has_vertices()[0] + except (Exception,): + raise DAGVertexException("Could not find the configuration vertex for the infrastructure graph.") + return configuration + + @property + def sync_point(self): + return self._dag.load(sync_point=0) + + def load(self, sync_point: int = 0): + ts = time.time() + res = self.dag.load(sync_point=sync_point) or 0 + self.logger.debug(f"infrastructure took {time.time()-ts} secs to load") + return res + + def close(self): + """ + Clean up resources held by this Infrastructure instance. + Releases the DAG instance and connection to prevent memory leaks. + """ + if self._dag is not None: + self._dag = None + self.conn = None + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures cleanup.""" + self.close() + return False + + def __del__(self): + self.close() + + def save(self, delta_graph: Optional[bool] = None): + if delta_graph is None: + delta_graph = self.delta_graph + + self.logger.debug(f"current sync point {self.last_sync_point}") + if delta_graph: + self.logger.debug("saving delta graph of the infrastructure") + ts = time.time() + self._dag.save(delta_graph=delta_graph) + self.logger.debug(f"infrastructure took {time.time()-ts} secs to save") + + def to_dot(self, graph_format: str = "svg", show_hex_uid: bool = False, + show_version: bool = True, show_only_active_vertices: bool = False, + show_only_active_edges: bool = False, sync_point: int = None, graph_type: str = "dot"): + + try: + mod = importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"DAG for Discovery", 
format=graph_format) + + if sync_point is None: + sync_point = self.last_sync_point + + self.logger.debug(f"generating infrastructure dot starting at sync point {sync_point}") + + self.dag.load(sync_point=sync_point) + + count = 0 + if len(self.dag.get_root.has_vertices()) > 0: + config_vertex = self.dag.get_root.has_vertices()[0] + count = len(config_vertex.has_vertices()) + + if graph_type == "dot": + dot.attr(rankdir='RL') + rank_sep = 10 + if count > 10: + rank_sep += int(count * 0.10) + dot.attr(ranksep=str(rank_sep)) + elif graph_type == "twopi": + rank_sep = 20 + if count > 20: + rank_sep += int(count * 0.10) + + dot.attr(layout="twopi") + dot.attr(ranksep=str(rank_sep)) + dot.attr(ratio="auto") + else: + dot.attr(layout=graph_type) + dot.attr(ranksep=10) + + for v in self.dag.all_vertices: + if show_only_active_vertices is True and v.active is False: + continue + + shape = "ellipse" + fillcolor = "white" + color = "black" + + if not v.corrupt: + + if not v.active: + fillcolor = "grey" + + record_type = None + record_uid = None + name = v.name + source = None + try: + data = v.content_as_dict + record_type = data.get("record_type") + record_uid = data.get("record_uid") + name = data.get("name") + item = data.get("item") + if item is not None: + if item.get("managed", False) is True: + shape = "box" + source = item.get("source") + if record_uid is not None: + fillcolor = "#AFFFAF" + if data.get("ignore_object", False): + fillcolor = "#DFDFFF" + except (Exception,): + pass + + label = f"uid={v.uid}" + if record_type is not None: + label += f"\\nrt={record_type}" + if name is not None and name != v.uid: + name = name.replace("\\", "\\\\") + label += f"\\nname={name}" + if source is not None: + label += f"\\nsource={source}" + if record_uid is not None: + label += f"\\nruid={record_uid}" + if show_hex_uid: + label += f"\\nhex={utils.base64_url_decode(v.uid).hex()}" + if v.uid == self.dag.get_root.uid: + fillcolor = "gold" + label += f"\\nsp={sync_point}" + + 
tooltip = f"ACTIVE={v.active}\\n\\n" + try: + content = v.content_as_dict + for k, val in content.items(): + if k == "item": + continue + if isinstance(val, str): + val = val.replace("\\", "\\\\") + tooltip += f"{k}={val}\\n" + + item = content.get("item") + if item is not None: + tooltip += f"------------------\\n" + for k, val in item.items(): + if isinstance(val, str): + val = val.replace("\\", "\\\\") + tooltip += f"{k}={val}\\n" + except Exception as err: + tooltip += str(err) + else: + fillcolor = "red" + label = f"{v.uid} (CORRUPT)" + tooltip = "CORRUPT" + + dot.node(v.uid, label, color=color, fillcolor=fillcolor, style="filled", shape=shape, tooltip=tooltip) + + head_uids = [] + for edge in v.edges: + + # Don't show edges that reference self, DATA and data that has been DELETION + if edge.head_uid == v.uid: + continue + + if edge.head_uid not in head_uids: + head_uids.append(edge.head_uid) + + def _render_edge(e): + + edge_color = "grey" + style = "solid" + + if e.corrupt is False: + + # To reduce the number of edges, only show the active edges + if e.active is True: + edge_color = "black" + style = "bold" + elif show_only_active_edges: + return + + # If the vertex is not active, gray out the DATA edge + if e.edge_type == EdgeType.DATA and v.active is False: + edge_color = "grey" + + if e.edge_type == EdgeType.DELETION: + style = "dotted" + + edge_tip = "" + if e.edge_type == EdgeType.ACL and v.active is True: + edge_content = e.content_as_dict + for key, value in content.items(): + edge_tip += f"{key}={value}\\n" + if edge_content.get("is_admin") is True: + edge_color = "red" + + edge_label = DAG.EDGE_LABEL.get(e.edge_type) + if edge_label is None: + edge_label = "UNK" + if e.path is not None and e.path != "": + edge_label += f"\\npath={e.path}" + if show_version: + edge_label += f"\\nv={e.version}" + else: + edge_label = f"{e.edge_type.value} (CORRUPT)" + edge_color = "red" + edge_tip = "CORRUPT" + + # tail, head (arrow side), label, ... 
+ dot.edge(v.uid, e.head_uid, edge_label, style=style, fontcolor=edge_color, color=edge_color, + tooltip=edge_tip) + + for head_uid in head_uids: + version, edge = v.get_highest_edge_version(head_uid) + _render_edge(edge) + + data_edge = v.get_data() + if data_edge is not None: + _render_edge(data_edge) + + return dot + + def render(self, name: str, **kwargs): + + output_name = os.environ.get("GRAPH_DIR", os.environ.get("HOME", os.environ.get("PROFILENAME", "."))) + output_name = os.path.join(output_name, name) + dot = self.to_dot(**kwargs) + dot.render(output_name) diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py new file mode 100644 index 00000000..44622b5f --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py @@ -0,0 +1,525 @@ + +import base64 +import copy +import logging +import os +from time import time +from typing import Any, Optional + + +from ..keeper_dag.dag_utils import get_connection, make_agent +from ..keeper_dag.dag import DAG, EdgeType +from ..keeper_dag.dag_types import PamGraphId, JobContent, Settings, JobItem, DiscoveryDelta +import importlib +from typing import Any, Optional, List, TYPE_CHECKING + +if TYPE_CHECKING: + from ..keeper_dag.dag import DAGVertex + +class Jobs: + + KEY_PATH = "jobs" + + # Break up the serialized delta. + # This is so it fits in data edge with is limit to a MySQL BLOB, 65k + # The content will be encrypted and base64, so this delta size needs to take that in account. + DELTA_SIZE = 48_000 + + # Only keep history for the last 30 runs. 
+ HISTORY_LIMIT = 30 + + # Limit stacktrace characters + STACKTRACE_LIMIT = 20_000 + + # Limit the length of the error message in JobContent + ERROR_LIMIT = 10_000 + SUMMARY_ERROR_LIMIT = 40 + + def __init__(self, record: Any, logger: Optional[Any] = None, debug_level: int = 0, fail_on_corrupt: bool = True, + log_prefix: str = "GS Jobs", save_batch_count: int = 200, agent: Optional[str] = None, + **kwargs): + + self.conn = get_connection(logger=logger, **kwargs) + + # This will either be a KSM Record, or Commander KeeperRecord + self.record = record + self._dag = None + if logger is None: + logger = logging.getLogger() + logger.propagate = False + self.logger = logger + self.log_prefix = log_prefix + self.debug_level = debug_level + self.fail_on_corrupt = fail_on_corrupt + self.save_batch_count = save_batch_count + + self.agent = make_agent("jobs") + if agent is not None: + self.agent += "; " + agent + + @property + def dag(self) -> DAG: + if self._dag is None: + + self._dag = DAG(conn=self.conn, + record=self.record, + # endpoint=PamEndpoints.DISCOVERY_JOBS, + graph_id=PamGraphId.DISCOVERY_JOBS, + auto_save=False, + logger=self.logger, + debug_level=self.debug_level, + name="Discovery Jobs", + fail_on_corrupt=self.fail_on_corrupt, + log_prefix=self.log_prefix, + save_batch_count=self.save_batch_count, + agent=self.agent) + + ts = time() + self._dag.load() + self.logger.debug(f"jobs took {time() - ts} secs to load") + + # Has the status been initialized? + if not self._dag.has_graph: + self._dag.allow_auto_save = False + status = self._dag.add_vertex() + status.belongs_to_root( + EdgeType.KEY, + path=Jobs.KEY_PATH) + status.add_data( + content=JobContent( + active_job_id=None, + job_history=[] + ), + ) + self._dag.save() + return self._dag + + def close(self): + """ + Clean up resources held by this Jobs instance. + Releases the DAG instance and connection to prevent memory leaks. 
+ """ + if self._dag is not None: + self._dag = None + self.conn = None + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures cleanup.""" + self.close() + return False + + def __del__(self): + self.close() + + @property + def data_path(self): + return f"/{Jobs.KEY_PATH}" + + def get_jobs(self): + + self.logger.debug("loading discovery jobs from DAG") + + vertex = self.dag.walk_down_path(self.data_path) + current_dict = vertex.content_as_dict + + if current_dict is None: + self.logger.debug(" there is no job content, creating empty job content") + vertex.add_data( + content=JobContent( + active_job_id=None, + job_history=[] + ), + ) + current_dict = vertex.content_as_dict + + # For job_history, settings will be blank/defaults. + # This make sure setting is set to blank, if it was None. + for job in current_dict.get("job_history", []): + job["settings"] = {} + + return JobContent.model_validate(current_dict) + + def _chunk_delta_data(self, job_vertex: "DAGVertex", delta: DiscoveryDelta): + + # From the job vertex we want to create vertices to hold the delta information. + # Break them up based on the DELTA_SIZE. + # Each DATA edge will contain part of the content. + # Each delta vertex has a path so we know the order on how to re-assemble them. 
+ delta_content = delta.model_dump_json() + self.logger.debug(f"job delta content is {len(delta_content)} bytes, chunk size is {Jobs.DELTA_SIZE} bytes") + + existing_delta_vertices = job_vertex.has_vertices() + if len(existing_delta_vertices) > 0: + self.logger.debug(f"job delta exists, remove old delta") + for delta_vertex in existing_delta_vertices: + delta_vertex.delete() + + chunk_num = 0 + while delta_content != "": + path = str(chunk_num) + + chunk = delta_content[:Jobs.DELTA_SIZE] + delta_content = delta_content[Jobs.DELTA_SIZE:] + + new_vertex = job_vertex.dag.add_vertex() + new_vertex.belongs_to(job_vertex, edge_type=EdgeType.KEY, path=path) + new_vertex.add_data(chunk) + + self.logger.debug(f" * vertex {new_vertex.uid}, chunk {chunk_num}, {len(chunk)} bytes") + + chunk_num += 1 + + def set_jobs(self, jobs: JobContent): + + self.logger.debug("saving discovery jobs to DAG") + + # Get the main vertex. + jobs_vertex = self.dag.walk_down_path(self.data_path) + + clean_jobs = [] + for job in jobs.job_history: + + # Does the job vertex exist? + # If not created it. + job_vertex = jobs_vertex.walk_down_path(job.job_id) + if job_vertex is None: + self.logger.debug(f" create a job vertex for {job.job_id}") + job_vertex = jobs_vertex.dag.add_vertex() + job_vertex.belongs_to(jobs_vertex, edge_type=EdgeType.KEY, path=job.job_id) + else: + self.logger.debug(f" job vertex for {job.job_id} exists") + + # If the job has delta data, chunk save it and remove it from the JobItem. + # If is not store in the history anymore. + if job.delta is not None: + self.logger.debug(" included discovery delta") + self._chunk_delta_data(job_vertex, job.delta) + job.delta = None + else: + self.logger.debug(" did not include discovery delta") + + # In-case the stacktrace is too large, take only a limit about of characters from the end. 
+ if job.stacktrace is not None: + self.logger.debug(f"stacktrace is {len(job.stacktrace)} characters") + if len(job.stacktrace) > Jobs.STACKTRACE_LIMIT: + self.logger.debug(f" stacktrace too long; truncate to {Jobs.STACKTRACE_LIMIT} characters") + start = len(job.stacktrace) - Jobs.STACKTRACE_LIMIT + job.stacktrace = job.stacktrace[start:] + + # Reduce the error message, if set, and remove stacktrace. + if job.error is not None: + self.logger.debug(f"error is {len(job.error)} characters") + if len(job.error) > Jobs.ERROR_LIMIT: + self.logger.debug(f" error too long; truncate to {Jobs.ERROR_LIMIT} characters") + job.error = job.error[:Jobs.ERROR_LIMIT] + "..." + + # Store the full JobItem (minus delta) on the job vertex. + job_vertex.add_data( + content=job + ) + + # Reduce the error message, if set, and remove stacktrace. + if job.error is not None and len(job.error) > Jobs.SUMMARY_ERROR_LIMIT: + job.error = job.error[:Jobs.SUMMARY_ERROR_LIMIT] + "..." + job.stacktrace = None + job.settings = Settings() + + clean_jobs.append(job) + + # Store the JobContent, with reduced JobItems, on the main vertex. + # This still has the actives and list of job history. + jobs.job_history = clean_jobs + jobs_vertex.add_data( + content=jobs + ) + + ts = time() + self.dag.save() + self.logger.debug(f"jobs took {time()-ts} secs to save") + + self.logger.debug(" finished saving") + + def _remove_old_history(self, job_history: List[JobItem], limit: int) -> List[JobItem]: + + self.logger.debug("clean up job history and migrate discovery delta") + + # The oldest will be first (lower start_ts, older the job) + job_history = sorted(job_history, key=lambda j: j.start_ts) + + # Limit the number of job history to the last few jobs. 
+        while (len(list(job_history))) > limit:
+            job = job_history[0]
+            self.logger.debug(f"remove job {job.job_id} item")
+            job_history = job_history[1:]
+            job_vertex = self.dag.walk_down_path(f"{self.data_path}/{job.job_id}")
+            if job_vertex is not None:
+                self.logger.debug(f"remove job {job.job_id} vertex")
+                job_vertex.delete()
+
+        self.logger.debug(f"found {len(job_history)} items in job history")
+
+        return job_history
+
+    def start(self, settings: Optional[Settings] = None, resource_uid: Optional[str] = None,
+              conversation_id: Optional[str] = None) -> str:
+
+        """
+        Start a discovery job.
+        """
+
+        self.logger.debug("starting a discovery job")
+
+        if settings is None:
+            settings = Settings()
+        else:
+            # We want to remove the user_map, because it may contain a lot of data; it might break the graph.
+            # Make a copy of settings, remove the user map, and save this version of settings.
+            settings = copy.deepcopy(settings)
+            settings.user_map = None
+
+        jobs = self.get_jobs()
+
+        # The -1 reserves a slot for the new job we are about to add; once the job has been started,
+        # the history will be at the limit.
+        job_history = self._remove_old_history(jobs.job_history, limit=Jobs.HISTORY_LIMIT - 1)
+
+        new_job = JobItem(
+            job_id="JOB" + base64.urlsafe_b64encode(os.urandom(8)).decode().rstrip('='),
+            start_ts=int(time()),
+            settings=settings,
+            resource_uid=resource_uid,
+            conversation_id=conversation_id,
+
+            # Create a blank discovery delta.
+            # Commander has a bug where it needs at least one.
+            # It will be overwritten when the job is finished.
+ delta=DiscoveryDelta() + ) + jobs.active_job_id = new_job.job_id + job_history.append(new_job) + jobs.job_history = job_history + + self.set_jobs(jobs) + + return new_job.job_id + + def get_job_content(self) -> JobContent: + jobs = self.dag.walk_down_path(path=self.data_path) + return jobs.content_as_object(JobContent) + + def get_job(self, job_id) -> Optional[JobItem]: + jobs = self.get_jobs() + for job in jobs.job_history: + if job.job_id == job_id: + + job_vertex = self.dag.walk_down_path(path=f"{self.data_path}/{job.job_id}") + if job_vertex is not None: + + # Get the job item from the job vertex DATA edge. + # Replace the one from the job history if we have it. + try: + job = job_vertex.content_as_object(JobItem) + except Exception as err: + self.logger.debug(f"could not find job item on job vertex, use job histry entry: {err}") + + # If the job delta is None, check to see if it chunked as vertices. + delta_lookup = {} + vertices = job_vertex.has_vertices() + self.logger.debug(f"found {len(vertices)} delta vertices") + for vertex in vertices: + edge = vertex.get_edge(job_vertex, edge_type=EdgeType.KEY) + delta_lookup[int(edge.path)] = vertex + + json_value = "" + # Sort numerically increasing and then append their content. + # This will re-assemble the JSON + for key in sorted(delta_lookup): + json_value += delta_lookup[key].content_as_str + if json_value != "": + self.logger.debug(f"delta content length is {len(json_value)}") + job.delta = DiscoveryDelta.model_validate_json(json_value) + else: + self.logger.debug("could not find job vertex") + + # If settings was not set, then set it the default. 
+ if job.settings is None: + job.settings = Settings() + + return job + return None + + def error(self, job_id: str, error: Optional[str], stacktrace: Optional[str] = None): + + self.logger.debug("flag discovery job as error") + + jobs = self.get_jobs() + for job in jobs.job_history: + if job.job_id == job_id: + logging.debug("found job to add error message") + job.end_ts = int(time()) + job.success = False + job.error = error + job.stacktrace = stacktrace + + self.set_jobs(jobs) + + def finish(self, job_id: str, sync_point: int, delta: DiscoveryDelta): + + self.logger.debug("finish discovery job") + + jobs = self.get_jobs() + for job in jobs.job_history: + if job.job_id == job_id: + self.logger.debug("found job to finish") + job.sync_point = sync_point + job.end_ts = int(time()) + job.success = True + job.delta = delta + + self.set_jobs(jobs) + + def cancel(self, job_id): + + self.logger.debug("cancel discovery job") + + jobs = self.get_jobs() + for job in jobs.job_history: + if job.job_id == job_id: + self.logger.debug("found job to cancel") + if job.end_ts is None: + job.end_ts = int(time()) + jobs.active_job_id = None + self.set_jobs(jobs) + + @property + def history(self) -> List[JobItem]: + jobs = self.get_jobs() + return jobs.job_history + + @property + def job_id_list(self) -> List[str]: + return [j.job_id for j in self.history] + + @property + def current_job(self) -> Optional[JobItem]: + """ + Get the current job + + The current job is the oldest unprocessed job + """ + jobs = self.get_jobs() + if jobs.active_job_id is None: + return None + return self.get_job(jobs.active_job_id) + + def __str__(self): + def _h(i: JobItem): + return f"Job ID: {i.job_id}, {i.success}, {i.sync_point} " + + ret = "HISTORY\n" + for item in self.history: + ret += _h(item) + return ret + + def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only_active_vertices: bool = True, + show_only_active_edges: bool = True, graph_type: str = "dot"): + + try: + mod 
= importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"DAG for Jobs", format=graph_format) + + if graph_type == "dot": + dot.attr(rankdir='RL') + elif graph_type == "twopi": + dot.attr(layout="twopi") + dot.attr(ranksep="10") + dot.attr(ratio="auto") + else: + dot.attr(layout=graph_type) + + self.logger.debug(f"have {len(self.dag.all_vertices)} vertices") + for v in self.dag.all_vertices: + + if show_only_active_vertices is True and v.active is False: + continue + + fillcolor = "white" + tooltip = "" + + for edge in v.edges: + + color = "grey" + style = "solid" + + # To reduce the number of edges, only show the active edges + if edge.active: + color = "black" + style = "bold" + elif show_only_active_edges: + continue + + # If the vertex is not active, gray out the DATA edge + if edge.edge_type == EdgeType.DATA: + if not v.active: + color = "grey" + elif v.has_data: + + try: + data = v.content_as_object(JobContent) # type: JobContent + if data.active_job_id is not None: + tooltip = f"Current Job Id: {data.active_job_id}\n"\ + f"History: \n" + for item in data.job_history: + tooltip += f" * {item.job_id}, {item.sync_point}, {item.start_ts_str}, "\ + f"{item.delta}, {item.error}\n" + fillcolor = "#FFFF00" + else: + fillcolor = "#CFCFFF" + except (Exception,): + try: + data = v.content_as_object(JobItem) # type: JobItem + if data.job_id is not None: + tooltip = f"Job Id: {data.job_id}\n" \ + f"Resource ID: {data.resource_uid}\n" \ + f"Start Ts: {data.start_ts}\n" \ + f"End Ts: {data.end_ts}\n" \ + f"Converstion ID: {data.conversation_id}\n" \ + f"Error: {data.error}\n" \ + f"Stack Trace: {data.stacktrace}\n" \ + f"Sync Point: {data.sync_point}\n" + fillcolor = "#FFFFF0" + else: + fillcolor = "#CFCFFF" + except (Exception,): + fillcolor = "#CFCFFF" + + if edge.edge_type == EdgeType.DELETION: + style = "dotted" + + label = 
DAG.EDGE_LABEL.get(edge.edge_type) + if label is None: + label = "UNK" + if edge.path is not None and edge.path != "": + label += f"\\npath={edge.path}" + if show_version: + label += f"\\nv={edge.version}" + + # tail, head (arrow side), label, ... + dot.edge(v.uid, edge.head_uid, label, style=style, fontcolor=color, color=color) + + shape = "ellipse" + + color = "black" + if not v.active: + fillcolor = "grey" + + label = f"uid={v.uid}" + dot.node(v.uid, label, color=color, fillcolor=fillcolor, style="filled", shape=shape, tooltip=tooltip) + + return dot diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/process.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/process.py new file mode 100644 index 00000000..338bff23 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/process.py @@ -0,0 +1,1568 @@ +from __future__ import annotations +import logging +import os +from .constants import PAM_DIRECTORY, PAM_USER, VERTICES_SORT_MAP, LOCAL_USER, PAM_CONFIGURATIONS +from .jobs import Jobs +from .infrastructure import Infrastructure +from .record_link import RecordLink +from .user_service import UserService +from .rule import Rules +from .dag_types import (DiscoveryObject, DiscoveryUser, RecordField, RuleActionEnum, UserAcl, + PromptActionEnum, PromptResult, BulkRecordAdd, BulkRecordConvert, BulkProcessResults, + DirectoryInfo, NormalizedRecord) +from .dag_utils import value_to_boolean, split_user_and_domain +from .dag_sort import sort_infra_vertices +from .dag_types import EdgeType +from .dag_crypto import bytes_to_urlsafe_str +import hashlib +from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING + + +if TYPE_CHECKING: + from .dag_vertex import DAGVertex + DirectoryResult = Union[DirectoryInfo, List] + DirectoryUserResult = Union[NormalizedRecord, DAGVertex] + + +class QuitException(Exception): + """ + This exception used when the user wants to stop processing of the results, before the end. 
+    """
+    pass
+
+
+class UserNotFoundException(Exception):
+    """
+    We could not find the user.
+    """
+    pass
+
+
+class DirectoryNotFoundException(Exception):
+    """
+    We could not find the directory.
+    """
+    pass
+
+
+class NoDiscoveryDataException(Exception):
+    """
+    This exception is thrown when there is no discovery data.
+    This is not an error.
+    There is just nothing to do.
+    """
+    pass
+
+
+class Process:
+
+    """
+    Process discovery results
+
+    While this class updates the PAM/record linking graph, it does not save it.
+
+    """
+
+    # Warn when bulk record lists exceed this size (potential memory issue)
+    BULK_LIST_WARNING_THRESHOLD = 10000
+    # Hard limit for bulk record lists (safety mechanism)
+    BULK_LIST_MAX_SIZE = 50000
+
+    def __init__(self, record: Any, job_id: str, logger: Optional[Any] = None, debug_level: int = 0, **kwargs):
+        self.job_id = job_id
+        self.record = record
+
+        env_debug_level = os.environ.get("PROCESS_GS_DEBUG_LEVEL")
+        if env_debug_level is not None:
+            debug_level = int(env_debug_level)
+
+        # Remember what was passed in as kwargs.
+        self.passed_kwargs = kwargs
+
+        self.jobs = Jobs(record=record, logger=logger, debug_level=debug_level, **kwargs)
+        self.job = self.jobs.get_job(self.job_id)
+
+        # These are lazy-loaded, so the graph is not loaded here.
+        self.infra = Infrastructure(record=record, logger=logger,
+                                    debug_level=debug_level,
+                                    fail_on_corrupt=False,
+                                    **kwargs)
+        self.record_link = RecordLink(record=record, logger=logger, debug_level=debug_level, **kwargs)
+        self.user_service = UserService(record=record, logger=logger, debug_level=debug_level, **kwargs)
+
+        # This is the root UID for all graphs; get it from one of them.
+ self.configuration_uid = self.jobs.dag.uid + + if logger is None: + logger = logging.getLogger() + self.logger = logger + self.debug_level = debug_level + + self.logger.debug(f"discovery process is using configuration uid {self.configuration_uid}") + + def close(self): + """ + Clean up resources held by this Process instance. + Releases all DAG instances and connections to prevent memory leaks. + """ + + if self.jobs: + self.jobs.close() + self.jobs = None + if self.infra: + self.infra.close() + self.infra = None + if self.record_link: + self.record_link.close() + self.record_link = None + if self.user_service: + self.user_service.close() + self.user_service = None + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures cleanup.""" + self.close() + return False + + def __del__(self): + self.close() + + @staticmethod + def get_key_field(record_type: str) -> str: + return VERTICES_SORT_MAP.get(record_type)["key"] + + @staticmethod + def set_user_based_ids(configuration_uid: str, content: DiscoveryObject, parent_vertex: Optional[DAGVertex] = None): + + if configuration_uid is None: + raise ValueError("The configuration UID is None when trying to create an id and UID for user.") + + if content.item.user is None: + raise Exception("The user name is blank. Cannot make an ID for the user.") + + parent_content = DiscoveryObject.get_discovery_object(parent_vertex) + object_id = content.item.user + if "\\" in content.item.user: + # Remove the domain name from the user. + # [0] will be the domain, [1] will be the user. 
+ object_id = object_id.split("\\")[1] + if parent_content.record_type == PAM_DIRECTORY: + domain = parent_content.name + if not object_id.endswith(domain): + object_id += f"@{domain}" + else: + object_id += parent_content.id + + content.id = object_id + + uid = configuration_uid + content.object_type_value + object_id + m = hashlib.sha256() + m.update(uid.lower().encode()) + + content.uid = bytes_to_urlsafe_str(m.digest()[:16]) + + def populate_admin_content_ids(self, content: DiscoveryObject, parent_vertex: Optional[DAGVertex] = None): + + """ + Populate the id and uid attributes for content. + """ + + return self.set_user_based_ids(self.configuration_uid, content, parent_vertex) + + def get_keys_for_vertex(self, vertex: DAGVertex) -> List[str]: + """ + For the vertex + :param vertex: + :return: + """ + + content = DiscoveryObject.get_discovery_object(vertex) + key_field = self.get_key_field(content.record_type) + keys = [] + if key_field == "host_port": + if content.item.port is not None: + if content.item.host is not None: + keys.append(f"{content.item.host}:{content.item.port}".lower()) + if content.item.ip is not None: + keys.append(f"{content.item.ip}:{content.item.port}".lower()) + elif key_field == "host": + if content.item.host is not None: + keys.append(content.item.host.lower()) + if content.item.ip is not None: + keys.append(content.item.ip.lower()) + elif key_field == "user": + if content.parent_record_uid is not None: + if content.item.user is not None: + keys.append(f"{content.parent_record_uid}:{content.item.user}".lower()) + if content.item.dn is not None: + keys.append(f"{content.parent_record_uid}:{content.item.dn}".lower()) + return keys + + def _update_with_record_uid(self, record_cache: dict, current_vertex: DAGVertex): + + # If the current vertex is not active, then return. + # It won't have a DATA edge. + if not current_vertex.active: + return + + for vertex in current_vertex.has_vertices(): + + # Skip if the vertex is not active. 
+ # It won't have a DATA edge. + if vertex.active is False or vertex.has_data is False: + continue + + # Don't worry about "item" class type + content = DiscoveryObject.get_discovery_object(vertex) + + # If we are ignoring the object, then skip. + if content.action_rules_result == RuleActionEnum.IGNORE.value or content.ignore_object is True: + continue + elif content.record_uid is not None: + cache_keys = self.get_keys_for_vertex(vertex) + for key in cache_keys: + + # If we find an item in the cache, update the vertex with the record UID + if key in record_cache.get(content.record_type): + content.record_uid = record_cache.get(content.record_type).get(key) + vertex.add_data(content) + break + + # Process the vertices that belong to the current vertex. + self._update_with_record_uid( + record_cache=record_cache, + current_vertex=vertex, + ) + + @staticmethod + def _prepare_record(record_prepare_func: Callable, + bulk_add_records: List[BulkRecordAdd], + content: DiscoveryObject, + parent_content: DiscoveryObject, + vertex: DAGVertex, + admin_uid: Optional[str] = None, + context: Optional[Any] = None) -> DiscoveryObject: + """ + Prepare a record to be added. + + :param record_prepare_func: Function to call to prepare a record to be created. + :param bulk_add_records: List of records to be added. + :param content: Discovery content of the current discovery item. + :param parent_content: Discovery content of the parent of the current discovery item. + :param vertex: Infrastructure vertex of the current discovery item. + :params admin_uid: If resource, if there is an admin, this is the UID of that PAM User + :param context: The context; dictionary of random instances. 
+ :return: + """ + + record_to_be_added, record_uid = record_prepare_func( + content=content, + context=context + ) + if record_to_be_added is None: + raise Exception("Did not get prepare record.") + if record_uid is None: + raise Exception("The prepared record did not contain a record UID.") + + parent_record_uid = parent_content.record_uid + if parent_content.object_type_value == "providers": + parent_record_uid = None + bulk_add_records.append( + BulkRecordAdd( + title=content.title, + record=record_to_be_added, + record_type=content.record_type, + record_uid=record_uid, + parent_record_uid=parent_record_uid, + shared_folder_uid=content.shared_folder_uid, + admin_uid=admin_uid + ) + ) + + content.record_uid = record_uid + content.parent_record_uid = parent_content.record_uid + vertex.add_data(content) + + return content + + def _default_acl(self, + discovery_vertex: DAGVertex, + content: DiscoveryObject, + discovery_parent_vertex: DAGVertex) -> UserAcl: + # Check to see if this user already belongs to another record vertex, or belongs to this one. + belongs_to = False + is_admin = False + is_iam_user = False + + parent_content = DiscoveryObject.get_discovery_object(discovery_parent_vertex) + + # User record the already exists. + # This means the vertex has a record UID, doesn't mean it exists in the vault. + # It may have been added during this processing. + if content.record_exists is False: + belongs_to = True + + # Is this user the admin for the resource? + if parent_content.access_user is not None: + # If this user record's user matches the user that was used to log into the parent resource, + # then this user is the admin for the parent resource. + if parent_content.access_user.user == content.item.user: + is_admin = True + + # User record does not exist. + else: + belongs_to_record_vertex = self.record_link.acl_has_belong_to_vertex(discovery_vertex) + + # If the user doesn't belong to any other vertex, it will be long the parent resource. 
+ if belongs_to_record_vertex is None: + self.logger.debug(" user vertex does not belong to another resource vertex") + belongs_to = True + + else: + parent_record_vertex = self.record_link.get_record_uid(discovery_parent_vertex) + if parent_record_vertex is not None: + if belongs_to_record_vertex == parent_record_vertex: + self.logger.debug(" user vertex already belongs to the parent resource vertex") + belongs_to = True + else: + self.logger.debug(" user vertex does not belong to any other resource vertex") + + # If the parent resource is a provider, then this user is an IAM user. + if parent_content.object_type_value == "providers": + is_iam_user = True + + acl = UserAcl.default() + acl.belongs_to = belongs_to + acl.is_admin = is_admin + acl.is_iam_user = is_iam_user + + return acl + + def _directory_exists(self, domain: str, directory_info_func: Callable, context: Any) -> Optional[DirectoryResult]: + + """ + This method will find the directory in the Infrastructure graph or in the Vault. + + If the domain contains more than one DC, the domain will be split and the full DC will be search and then + the first DC. + For example, if EXAMPLE.COM is passed in for the domain, EXAMPLE.COM and EXAMPLE will be searched for. + + The Infrastructure graph will be searched first. + If nothing is found, the Vault will be searched. + + If the directory is found in the graph, a list if directory vertices will be returned. + If the directory is found in the Vault, a DirectoryInfo instance will be returned. + If nothing is found, None is returned. + + The returned results can be passed to the _find_directory_user method. + + """ + + domains = [domain] + if "." in domains: + domains.append(domain.split(".")[0]) + + self.logger.debug(f"search for directories: {', '.join(domains)}") + + # Some providers provider directory type services. 
+ # They can also provide multiple domains + provider_vertices = self.infra.dag.search_content({ + "record_type": ["pamAzureConfiguration", "pamDomainConfiguration"], + }, ignore_case=True) + found_provider_directories = [] + for provider_vertex in provider_vertices: + content = DiscoveryObject.get_discovery_object(provider_vertex) + found = False + for domain in domains: + for provider_domain in content.item.info.get("domains", []): + if domain.lower() in provider_domain.lower(): + found = True + break + if found: + break + if found: + found_provider_directories.append(provider_vertex) + if len(found_provider_directories) > 0: + return found_provider_directories + + # Check the graph first. + # `search_content` does an "is in" type match; so subdomains should match a full domain + # pamDomainConfiguration is an edge case because it's name in the record is the domain name. + for domain_name in domains: + directories = self.infra.dag.search_content({ + "record_type": ["pamDirectory", "pamDomainConfiguration"], + "name": domain_name + }, ignore_case=True) + + self.logger.debug(f"found {len(directories)} directories in the graph") + + # If we found directories, return the list of directory vertices. + if len(directories) > 0: + # Return vertices + return directories + + # Check the vault secondly. + for domain_name in domains: + info = directory_info_func(domain=domain_name, skip_users=False, context=context) + if info is not None: + # If we found directories in the Vault, then return directory info + # This will be an instance of DirectoryInfo + return info + + return None + + def _find_directory_user(self, + results: DirectoryResult, + record_lookup_func: Callable, + context: Any, + find_user: Optional[str] = None, + find_dn: Optional[str] = None) -> Optional[DirectoryUserResult]: + + # If the passed in results were a DirectoryInfo then check the Vault for users. 
+ if isinstance(results, DirectoryInfo): + self.logger.debug("search for directory user from vault records") + self.logger.debug(f"have {len(results.directory_user_record_uids)} users") + for user_record_id in results.directory_user_record_uids: + record = record_lookup_func(record_uid=user_record_id, context=context) # type: NormalizedRecord + if record is not None: + found = None + self.logger.debug(f"find user {find_user}, dn {find_dn}") + if find_user is not None: + found = record.find_user(find_user) + if found is None and find_dn is not None: + found = record.find_dn(find_dn) + return found + return None + + # Else it was a list of directory vertices, check its children for the users. + else: + self.logger.debug("search for directory user from the graph") + for directory_vertex in results: # type: DAGVertex + for user_vertex in directory_vertex.has_vertices(): + user_content = DiscoveryObject.get_discovery_object(user_vertex) + + # We should only have pamUser vertices. + if user_content.record_type != PAM_USER: + self.logger.debug(f"in find directory user, a vertex {user_vertex.uid} was not a pamUser, " + f"was {user_content.record_type}.") + continue + + found_vertex = None + if find_user is not None: + user, domain = split_user_and_domain(find_user) + if user_content.item.user.lower() == user.lower(): + found_vertex = user_vertex + elif user_content.item.user.lower() == find_user.lower(): + found_vertex = user_vertex + elif find_dn is not None: + if user_content.item.dn.lower() == find_dn.lower(): + found_vertex = user_vertex + + if found_vertex is not None: + return found_vertex + return None + + def _record_link_directory_users(self, + directory_vertex: DAGVertex, + directory_content: DiscoveryObject, + directory_info_func: Callable, + context: Optional[Any] = None): + + """ + Link user record to directory when adding a new directory. + + When adding a new directory, there may be other directories for the same domain. 
+ We need to link existing directory users, of the same domain, to this new directory. + + """ + + self.logger.debug(f"resource is directory; connect users to this directory for {directory_vertex.uid}") + + record_link = context.get("record_link") # type: RecordLink + + # Get the directory user record UIDs from the vault that belong to directories using the same domain. + directory_info = directory_info_func( + domain=directory_content.name, + context=context + ) # type: DirectoryInfo + if directory_info is None: + self.logger.debug("there were no directory record for this domain") + directory_info = DirectoryInfo() + + user_record_uids = directory_info.directory_user_record_uids + + self.logger.debug(f"found {len(directory_info.directory_user_record_uids)} users" + f"from {len(directory_info.directory_record_uids)} directories.") + + # Check our current discovery data. + # This is a delta, it will not contain discovery from prior runs. + # This will only contain objects in this run. + # Make sure the object is a directory and the domain is the same. + # Also make sure there is a record UID; it might not be added yet. 
+ self.logger.debug("finding directories in discovery vertices") + for parent_vertex in directory_vertex.belongs_to_vertices(): + self.logger.debug(f"find directories under {parent_vertex.uid}") + for other_directory_vertex in parent_vertex.has_vertices(): + if other_directory_vertex.uid == directory_vertex.uid: + self.logger.debug(" skip this directory, it's the current one") + continue + other_directory_content = DiscoveryObject.get_discovery_object(other_directory_vertex) + self.logger.debug(f"{other_directory_content.record_type}, {other_directory_content.name}, " + f"{other_directory_content.uid}, {other_directory_content.record_uid}") + if (other_directory_content.record_type == PAM_DIRECTORY + and other_directory_content.name == directory_content.name + and other_directory_content.record_uid is not None): + self.logger.debug(f"check {other_directory_content.uid} for users") + for user_vertex in other_directory_vertex.has_vertices(): + user_content = DiscoveryObject.get_discovery_object(user_vertex) + self.logger.debug(f" * {user_vertex.uid}, {user_content.record_uid}") + if user_content.record_uid is not None and user_content.record_uid not in user_record_uids: + user_record_uids.append(user_content.record_uid) + del user_content + del other_directory_content + + self.logger.debug(f"found {len(user_record_uids)} user to connect to directory") + + # Make sure there is a link from the user record to the directory record. + # We also might need to make a KEY edge from the user to the directory if one does not exist. + for record_uid in user_record_uids: + if record_link.get_acl(record_uid, directory_content.record_uid) is None: + record_link.belongs_to(record_uid, directory_content.record_uid, acl=UserAcl.default()) + + # Check if the user vertex has a KEY edge to the directory_vertex. 
+ found_vertices = directory_vertex.dag.search_content({"record_uid": record_uid}) + if len(found_vertices) == 1: + user_vertex = found_vertices[0] + if user_vertex.get_edge(directory_vertex, EdgeType.KEY) is None: + self.logger.debug(f"adding a KEY edge from the user {user_vertex.uid} to {directory_vertex.uid}") + user_vertex.belongs_to(directory_vertex, EdgeType.KEY) + else: + self.logger.debug("could not find user vertex") + + def _find_admin_directory_user(self, + domain: str, + admin_acl: UserAcl, + directory_info_func: Callable, + record_lookup_func: Callable, + context: Any, + user: Optional[str] = None, + dn: Optional[str] = None) -> Optional[str]: + + # Check any directories for the domain exist. + results = self._directory_exists(domain=domain, + directory_info_func=directory_info_func, + context=context) + + if results is not None: + # Find the user (clean of domain) or DN in the found directories. + directory_user = self._find_directory_user(results=results, + record_lookup_func=record_lookup_func, + context=context, + find_user=user, + find_dn=dn) + if directory_user is not None: + + # If we got a normalized record, then a Vault record exists. + # No need to create a record, just link, belongs_to is False + # Since we are using records, just the belongs_to method instead of + # discovery_belongs_to. + if isinstance(directory_user, NormalizedRecord): + admin_acl.belongs_to = False + return directory_user.record_uid + else: + admin_content = DiscoveryObject.get_discovery_object(directory_user) + + # If not a PAM User, then this is bad. + if admin_content.record_type != PAM_USER: + self.logger.warning( + f"found record type {admin_content.record_type} instead of " + f"pamUser for record UID {admin_content.record_uid}") + return None + + # If the record UID exists, then connect the directory user to the + # resource. 
+ if admin_content.record_uid is not None: + admin_acl.belongs_to = False + return admin_content.record_uid + + return None + else: + raise UserNotFoundException(f"Could not find the directory user in domain {domain}") + else: + raise DirectoryNotFoundException(f"Could not find the directory for domain {domain}") + + def _process_auto_add_level(self, + current_vertex: DAGVertex, + bulk_add_records: List[BulkRecordAdd], + bulk_convert_records: List[BulkRecordConvert], + record_lookup_func: Callable, + record_prepare_func: Callable, + directory_info_func: Callable, + record_cache: dict, + context: Optional[Any] = None): + + """ + This method will add items to the bulk_add_records queue to be added by the client. + + These are items where the rule engine has flagged them to be added. + + :param current_vertex: The current/parent discovery vertex. + :param bulk_add_records: List of records to be added. + :param bulk_convert_records: List of existing records to be covert to this gateway. + :params record_lookup_func: A function to lookup records to see if they exist. + :param record_prepare_func: Function to convert content into an unsaved record. + :param directory_info_func: Function to lookup directories. + :param record_cache: + :param context: Client context; could be anything. + :return: + """ + + if not current_vertex.active: + self.logger.debug(f"vertex {current_vertex.uid} is not active, skip") + return + + # Check if this vertex has a record. + # We cannot add child vertices to a vertex that does not have a record. + current_content = DiscoveryObject.get_discovery_object(current_vertex) + if current_content.record_uid is None: + self.logger.debug(f"vertex {current_content.uid} does not have a record id") + return + + self.logger.debug(f"Current Vertex: {current_content.record_type}, {current_vertex.uid}, " + f"{current_content.name}") + + # Sort all the vertices under the current vertex. + # Return a dictionary where the record type is the key. 
+ # The value will be an array of vertices of the specific record type. + record_type_to_vertices_map = sort_infra_vertices(current_vertex, logger=self.logger) + + # Process the record type by their map order in ascending order. + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + self.logger.debug(f" processing {record_type}") + for vertex in record_type_to_vertices_map[record_type]: + + child_content = DiscoveryObject.get_discovery_object(vertex) + self.logger.debug(f" child vertex {vertex.uid}, {child_content.name}") + + # If we are going to add an admin user, this is the default ACL + # This is for the smart add feature + admin_acl = UserAcl.default() + admin_acl.is_admin = True + + # This ACL is None for resource, and populated for users. + default_acl = None + if child_content.record_type == PAM_USER: + default_acl = self._default_acl( + discovery_vertex=vertex, + content=child_content, + discovery_parent_vertex=current_vertex) + + # Check for a vault record, if it exists. + # Default to the DAG content. + # Check the bulk_add_records list, to make sure it is not in the list of record we are about to add. + # We are doing this because the record might be an active directory user, that we have + # not created a record for yet, however it might have been assigned a record UID from a prior prompt. 
+ + existing_record = child_content.record_exists + if record_lookup_func is not None: + check_the_vault = True + for item in bulk_add_records: + if item.record_uid == child_content.record_uid: + self.logger.debug(f" record is in the bulk add list, do not check the vault if exists") + check_the_vault = False + break + if check_the_vault: + existing_record = record_lookup_func(record_uid=child_content.record_uid, + context=context) is not None + self.logger.debug(f" record exists in the vault: {existing_record}") + else: + self.logger.debug(f" record lookup function not defined, record existing: {existing_record}") + + # Determine if we are going to add the item. + # If the item has a record UID already, we don't need to add. + add_record = False + if (child_content.record_exists is False + and child_content.action_rules_result == RuleActionEnum.ADD.value): + self.logger.debug(f" vertex {vertex.uid} had an ADD result for the rule engine, auto add") + add_record = True + + if add_record: + + self.logger.debug(f"adding resource record") + + # For a resource, the ACL will be None. + # It will a UserAcl if a user. + self.record_link.belongs_to(child_content.record_uid, current_content.record_uid, acl=default_acl) + + admin_uid = None + # If the rules have set the admin_uid then connect the user to the resource. 
+ if (child_content.admin_uid is not None + and child_content.record_type != PAM_USER + and record_lookup_func is not None): + + self.logger.debug("the admin UID has been set for this resource") + + admin_record = record_lookup_func(record_uid=child_content.admin_uid, + context=context) # type: NormalizedRecord + if admin_record is not None and admin_record.record_type == PAM_USER: + self.logger.debug("was able to find the admin record, connect to resource") + admin_uid = child_content.admin_uid + admin_acl.is_admin = True + self.record_link.belongs_to(child_content.admin_uid, child_content.record_uid, + acl=admin_acl) + else: + self.logger.info(f"The PAM User record {child_content.admin_uid} does not exists. " + "Cannot set the administrator for an auto added " + f"record {child_content.title}.") + + # The record could be a resource or user record. + self._prepare_record( + record_prepare_func=record_prepare_func, + bulk_add_records=bulk_add_records, + content=child_content, + parent_content=current_content, + vertex=vertex, + context=context, + admin_uid=admin_uid + ) + if child_content.record_uid is None: + raise Exception(f"the record uid is blank for {child_content.description} after prepare") + + # If the record type is a PAM User, we don't need to go deeper. + # In the future we might need to change if PAM User becomes a branch and not a leaf. + # This is for safety reasons + if child_content.record_type != PAM_USER: + # Process the vertices that belong to the current vertex. 
+ self._process_auto_add_level( + current_vertex=vertex, + bulk_add_records=bulk_add_records, + bulk_convert_records=bulk_convert_records, + record_lookup_func=record_lookup_func, + record_prepare_func=record_prepare_func, + directory_info_func=directory_info_func, + record_cache=record_cache, + context=context + ) + + self.logger.debug(f" finished auto add processing {record_type}") + self.logger.debug(f" Finished auto add current Vertex: {current_vertex.uid}, {current_content.name}") + + @staticmethod + def _apply_admin_uid(bulk_add_records: List[BulkRecordAdd], + resource_uid: str, + admin_uid: str): + + for item in bulk_add_records: + if item.record_uid == resource_uid: + item.admin_uid = admin_uid + break + + def _process_level(self, + current_vertex: DAGVertex, + bulk_add_records: List[BulkRecordAdd], + bulk_convert_records: List[BulkRecordConvert], + record_lookup_func: Callable, + prompt_func: Callable, + prompt_admin_func: Callable, + record_prepare_func: Callable, + directory_info_func: Callable, + record_cache: dict, + item_count: int = 0, + items_left: int = 0, + indent: int = 0, + context: Optional[Any] = None): + + """ + This method will walk the user through discovery delta objects. + + At this point, we only have the delta objects from the graph. + We do not have the full graph. + + :param current_vertex: The current/parent discovery vertex. + :param bulk_add_records: List of records to be added. + :param bulk_convert_records: List of existing records to be covert to this gateway. + :param prompt_func: Function to call for user prompt. + :param record_prepare_func: Function to convert content into an unsaved record. + :param indent: Amount to indent text. + :param context: Client context; could be anything. + :return: + """ + + if not current_vertex.active: + self.logger.debug(f"vertex {current_vertex.uid} is not active, skip") + return + + # Check if this vertex has a record. + # We cannot add child vertices to a vertex that does not have a record. 
+ current_content = DiscoveryObject.get_discovery_object(current_vertex) + if current_content.record_uid is None: + self.logger.debug(f"vertex {current_content.uid} does not have a record id") + return + + self.logger.debug(f"Current Vertex: {current_content.record_type}, {current_vertex.uid}, " + f"{current_content.name}") + + # Sort all the vertices under the current vertex. + # Return a dictionary where the record type is the key. + # The value will be an array of vertices of the specific record type. + record_type_to_vertices_map = sort_infra_vertices(current_vertex, logger=self.logger) + + # Process the record type by their map order in ascending order. + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + self.logger.debug(f" processing {record_type}") + for vertex in record_type_to_vertices_map[record_type]: + + child_content = DiscoveryObject.get_discovery_object(vertex) + self.logger.debug(f" child vertex {vertex.uid}, {child_content.name}") + + default_acl = None + if child_content.record_type == PAM_USER: + default_acl = self._default_acl( + discovery_vertex=vertex, + content=child_content, + discovery_parent_vertex=current_vertex) + + # Check for a vault record, if it exists. + # Default to the DAG content. + # Check the bulk_add_records list, to make sure it is not in the list of record we are about to add. + # We are doing this because the record might be an active directory user, that we have + # not created a record for yet, however it might have been assigned a record UID from a prior prompt. 
+ + existing_record = child_content.record_exists + if record_lookup_func is not None: + check_the_vault = True + for item in bulk_add_records: + if item.record_uid == child_content.record_uid: + self.logger.debug(f" record is in the bulk add list, do not check the vault if exists") + check_the_vault = False + break + if check_the_vault: + existing_record = record_lookup_func(record_uid=child_content.record_uid, + context=context) is not None + self.logger.debug(f" record exists in the vault: {existing_record}") + else: + self.logger.debug(f" record lookup function not defined, record existing: {existing_record}") + + # If we have a record UID, the record exists; we don't need to prompt the user. + # If a user, we do want to make sure an ACL exists between this user and the resource. + if existing_record is True: + self.logger.debug(f" record already exists.") + # Don't continue since we might want to recurse into its children. + + # If the rule engine result is to ignore this object, then continue. + # This normally would not happen since discovery wouldn't add the object. + # However, make sure we skip any object where the rule engine action is to ignore the object. + elif child_content.action_rules_result == RuleActionEnum.IGNORE.value: + self.logger.debug(f" vertex {vertex.uid} had a IGNORE result for the rule engine, " + "skip processing") + # If the rule engine result is to ignore this object, then continue. + continue + + # If this flag is set, the user set the ignore_object flag when prompted. + elif child_content.ignore_object: + self.logger.debug(f" vertex {vertex.uid} was flagged as ignore, skip processing") + # If the ignore_object flag is set, then continue. + continue + + # If the record doesn't exist, then prompt the user. + else: + self.logger.debug(f" vertex {vertex.uid} had an PROMPT result, prompt user") + + # For user record, check if the resource record has an admin. + # If not, prompt the user if they want to add this user as the admin. 
+ # The returned ACL will have the is_admin flag set to True if they do. + resource_has_admin = False + if child_content.record_type == PAM_USER: + resource_has_admin = (self.record_link.get_admin_record_uid(current_content.record_uid) + is not None) + self.logger.debug(f"resource has an admin is {resource_has_admin}") + + # If the current resource does not allow an admin, then it has and admin, it's just controlled by + # us. + # This is going to be a resource record, or a configuration record. + if hasattr(current_content.item, "allows_admin"): + if not current_content.item.allows_admin: + self.logger.debug(f"resource allows an admin is {current_content.item.allows_admin}") + resource_has_admin = True + else: + self.logger.debug(f"resource type {current_content.record_type} does not have " + "allows_admin attr") + + result = prompt_func( + vertex=vertex, + parent_vertex=current_vertex, + content=child_content, + acl=default_acl, + resource_has_admin=resource_has_admin, + indent=indent, + item_count=item_count, + items_left=items_left, + context=context) # type: PromptResult + + if result.action == PromptActionEnum.IGNORE: + self.logger.debug(f" vertex {vertex.uid} is being ignored from prompt") + result.content.ignore_object = True + + action_rule_item = Rules.make_action_rule_from_content( + content=result.content, + action=RuleActionEnum.IGNORE + ) + + # Add a rule to ignore this object when doing future discovery. + rules = Rules(record=self.record, **self.passed_kwargs) + rules.add_rule(action_rule_item) + + # Even though we are ignoring the object, we will still add it to the infrastructure graph. + # This is user selected ignored, not from the rule engine. + # vertex.belongs_to(current_vertex, EdgeType.KEY) + vertex.add_data(result.content) + + elif result.action == PromptActionEnum.ADD: + self.logger.debug(f" vertex {vertex.uid} is being added from prompt") + + # Use the content from the prompt. + # The user may have modified it. 
+ add_content = result.content + acl = result.acl + + # If the current content is a PAM configuration and the child content/add content is a PAM User; + # then the user is an IAM user. + if current_content.record_type in PAM_CONFIGURATIONS and add_content.record_type == PAM_USER: + acl.is_iam_user = True + + # The record could be a resource or user record. + # add_content will have record UID after this. + self._prepare_record( + record_prepare_func=record_prepare_func, + bulk_add_records=bulk_add_records, + content=add_content, + parent_content=current_content, + vertex=vertex, + context=context + ) + + admin_uid = None + + # Make a record link. + # The acl will be None if not a pamUser. + self.record_link.discovery_belongs_to(vertex, current_vertex, acl) + + # If the object is NOT a pamUser and the resource allows an admin. + # Prompt the user to create an admin. + should_prompt_for_admin = True + self.logger.debug(f" added record type was {add_content.record_type}") + if (add_content.record_type != PAM_USER and add_content.item.allows_admin is True and + prompt_admin_func is not None): + + self.logger.debug("checking if can add admin") + + # If the rule engine sets the admin UID + if child_content.admin_uid is not None and record_lookup_func is not None: + + self.logger.debug(f"the resource rule set the admin uid to {add_content.admin_uid}") + + admin_record = record_lookup_func(record_uid=add_content.admin_uid, + context=context) # type: NormalizedRecord + if admin_record is not None and admin_record.record_type == PAM_USER: + self.logger.debug("was able to find the admin record, connect to resource") + + admin_uid = add_content.admin_uid + # admin_acl = UserAcl.default() + # admin_acl.is_admin = True + # self.record_link.belongs_to(child_content.admin_uid, child_content.record_uid, + # acl=admin_acl) + should_prompt_for_admin = False + else: + self.logger.info(f"The PAM User record {child_content.admin_uid} does not exists. 
" + "Cannot set the administrator for an auto added " + f"record {child_content.title}.") + + # This block checks to see if the admin is a directory user that exists. + # We don't want to prompt the user for an admin if we have one already. + elif add_content.access_user is not None and add_content.access_user.user is not None: + + self.logger.debug(" for this resource, credentials were provided.") + self.logger.error(f" {add_content.access_user.user}, {add_content.access_user.dn}, " + f"{add_content.access_user.password}") + + # Check if this user is a directory users, first check the source. + # If local, check the username incase the domain in part of the username. + source = add_content.access_user.source + if add_content.record_type == PAM_DIRECTORY: + source = add_content.name + elif source == LOCAL_USER: + _, domain = split_user_and_domain(add_content.access_user.user) + if domain is not None: + source = domain + + if source != LOCAL_USER: + self.logger.debug(" admin was not a local user, " + f"find user in directory {source}, if exists.") + + acl = UserAcl.default() + acl.is_admin = True + + try: + admin_uid = self._find_admin_directory_user( + domain=source, + admin_acl=acl, + directory_info_func=directory_info_func, + record_lookup_func=record_lookup_func, + context=context, + user=add_content.access_user.user, + dn=add_content.access_user.dn + ) + + if admin_uid is not None: + self.logger.debug(" found directory user admin, connect to resource") + # self.record_link.belongs_to(admin_uid, add_content.record_uid, acl=acl) + should_prompt_for_admin = False + else: + self.logger.debug(" did not find the directory user for the admin, " + "prompt the user") + except DirectoryNotFoundException: + self.logger.debug(f" directory {source} was not found for admin user") + except UserNotFoundException: + self.logger.debug(f" directory user was not found in directory {source}") + + if should_prompt_for_admin: + self.logger.debug(f" prompt for admin user") + admin_uid 
= self._process_admin_user( + resource_vertex=vertex, + resource_content=add_content, + bulk_add_records=bulk_add_records, + bulk_convert_records=bulk_convert_records, + record_lookup_func=record_lookup_func, + directory_info_func=directory_info_func, + prompt_admin_func=prompt_admin_func, + record_prepare_func=record_prepare_func, + indent=indent, + context=context + ) + + # If we have an admin UID, add it to the last bulk record. + # It will be the one we added above. + if admin_uid is not None: + + self._apply_admin_uid( + bulk_add_records=bulk_add_records, + resource_uid=add_content.record_uid, + admin_uid=admin_uid + ) + + items_left -= 1 + + # If the record type is a PAM User, we don't need to go deeper. + # In the future we might need to change if PAM User becomes a branch and not a leaf. + # This is for safety reasons + if child_content.record_type != PAM_USER: + # Process the vertices that belong to the current vertex. + self._process_level( + current_vertex=vertex, + bulk_add_records=bulk_add_records, + bulk_convert_records=bulk_convert_records, + record_lookup_func=record_lookup_func, + prompt_func=prompt_func, + prompt_admin_func=prompt_admin_func, + record_prepare_func=record_prepare_func, + directory_info_func=directory_info_func, + record_cache=record_cache, + indent=indent + 1, + item_count=item_count, + items_left=items_left, + context=context + ) + self.logger.debug(f" finished processing {record_type}") + self.logger.debug(f" Finished current Vertex: {current_vertex.uid}, {current_content.name}") + + def _process_admin_user(self, + resource_vertex: DAGVertex, + resource_content: DiscoveryObject, + bulk_add_records: List[BulkRecordAdd], + bulk_convert_records: List[BulkRecordConvert], + record_lookup_func: Callable, + directory_info_func: Callable, + prompt_admin_func: Callable, + record_prepare_func: Callable, + indent: int = 0, + context: Optional[Any] = None) -> Optional[str]: + + # If the access_user is None, create an empty one. 
+ # We will need this below when adding values to the fields. + if resource_content.access_user is None: + resource_content.access_user = DiscoveryUser() + + # Initialize a discovery object for the admin user. + # The PLACEHOLDER will be replaced after the admin user prompt. + + values = {} + for field in ["user", "password", "private_key", "dn", "database"]: + value = getattr(resource_content.access_user, field) + if value is None: + value = [] + else: + value = [value] + values[field] = value + + managed = [False] + if resource_content.access_user.managed is not None: + managed = [resource_content.access_user.managed] + + admin_content = DiscoveryObject( + uid="PLACEHOLDER", + object_type_value="users", + parent_record_uid=resource_content.record_uid, + record_type=PAM_USER, + id="PLACEHOLDER", + name="PLACEHOLDER", + description=resource_content.description + ", Administrator", + title=resource_content.title + ", Administrator", + item=DiscoveryUser( + user="PLACEHOLDER" + ), + fields=[ + RecordField(type="login", label="login", value=values["user"], required=True), + RecordField(type="password", label="password", value=values["password"], required=False), + RecordField(type="secret", label="privatePEMKey", value=values["private_key"], required=False), + RecordField(type="text", label="distinguishedName", value=values["dn"], required=False), + RecordField(type="text", label="connectDatabase", value=values["database"], required=False), + RecordField(type="checkbox", label="managed", value=managed, required=False), + ] + ) + + admin_acl = UserAcl.default() + admin_acl.is_admin = True + + # Prompt to add an admin user to this resource. + # We are not passing an ACL instance. + # We'll make it based on if the user is adding a new record or linking to an existing record. 
+ admin_result = prompt_admin_func( + parent_vertex=resource_vertex, + content=admin_content, + acl=admin_acl, + bulk_convert_records=bulk_convert_records, + indent=indent, + context=context + ) + + # If the action is to ADD, replace the PLACEHOLDER data. + if admin_result.action == PromptActionEnum.ADD: + self.logger.debug("adding admin user") + + source = "local" + if resource_content.record_type == PAM_DIRECTORY: + source = resource_content.name + + admin_record_uid = admin_result.record_uid + + if admin_record_uid is None: + admin_content = admin_result.content + + # With the result, we can fill in information in the object item. + admin_content.item.user = admin_content.get_field_value("login") + admin_content.item.password = admin_content.get_field_value("password") + admin_content.item.private_key = admin_content.get_field_value("privatePEMKey") + admin_content.item.dn = admin_content.get_field_value("distinguishedName") + admin_content.item.database = admin_content.get_field_value("connectDatabase") + admin_content.item.managed = value_to_boolean( + admin_content.get_field_value("managed")) or False + admin_content.item.source = source + admin_content.name = admin_content.item.user + + self.logger.debug(f"added admin user from content") + + if admin_content.item.user is None or admin_content.item.user == "": + raise ValueError("The user name is missing or is blank. Cannot create the administrator user.") + + if admin_content.name is not None: + admin_content.description = (resource_content.description + ", User " + + admin_content.name) + + # We need to populate the id and uid of the content, now that we have data in the content. 
+ self.populate_admin_content_ids(admin_content, resource_vertex) + + ad_user, ad_domain = split_user_and_domain(admin_content.item.user) + if ad_domain is not None and admin_content.item.source == LOCAL_USER: + self.logger.debug("The admin is an directory user, but the source is set to a local user") + + found_admin_record_uid = None + try: + found_admin_record_uid = self._find_admin_directory_user( + domain=ad_domain, + admin_acl=admin_acl, + directory_info_func=directory_info_func, + record_lookup_func=record_lookup_func, + context=context, + user=admin_content.item.user, + dn=admin_content.item.dn + ) + except DirectoryNotFoundException: + self.logger.debug(f" directory {source} was not found for admin user") + except UserNotFoundException: + self.logger.debug(f" directory user was not found in directory {source}") + + if found_admin_record_uid is not None: + self.logger.debug(" found directory user admin, connect to resource") + found_admin_vertices = self.infra.dag.search_content({"record_uid": found_admin_record_uid}) + if len(found_admin_vertices) == 1: + found_admin_vertices[0].belongs_to(resource_vertex, edge_type=EdgeType.KEY) + self.record_link.belongs_to(found_admin_record_uid, resource_content.record_uid, + acl=admin_acl) + return found_admin_record_uid + + # Does an admin vertex already exist for this user? + # This most likely user on the gateway, since without a resource record users can be discovered. + # If we did find it, get the content for the admin; we really want any existing record uid. + admin_vertex = self.infra.dag.get_vertex(admin_content.uid) + if admin_vertex is not None and admin_vertex.active is True and admin_vertex.has_data is True: + self.logger.debug("admin exists in the graph") + found_content = DiscoveryObject.get_discovery_object(admin_vertex) + admin_record_uid = found_content.record_uid + else: + self.logger.debug("admin does not exists in the graph") + + # If there is a record UID for the admin user, connect it. 
+ if admin_record_uid is not None: + self.logger.debug("the admin has a record UID") + + # If the admin record does not belong to another resource, make this resource its owner. + if self.record_link.get_parent_record_uid(admin_record_uid) is None: + self.logger.debug("the admin does not belong to another resources, " + "setting it belong to this resource") + admin_acl.belongs_to = True + + admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) + self.record_link.belongs_to(admin_record_uid, resource_content.record_uid, acl=admin_acl) + else: + if admin_vertex is None: + self.logger.debug("creating an entry in the graph for the admin") + admin_vertex = self.infra.dag.add_vertex(uid=admin_content.uid, + name=admin_content.description) + + # Since this record does not exist, it will belong to the resource, + admin_acl.belongs_to = True + + # Connect the user vertex to the resource vertex. + # We need to add a KEY edge for the admin content stored on the DATA edge. + admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) + admin_vertex.add_data(admin_content) + + # The record will be a user record; admin_acl will not be None + self._prepare_record( + record_prepare_func=record_prepare_func, + bulk_add_records=bulk_add_records, + content=admin_content, + parent_content=resource_content, + vertex=admin_vertex, + context=context + ) + + self.record_link.discovery_belongs_to(admin_vertex, resource_vertex, acl=admin_acl) + + admin_record_uid = admin_content.record_uid + else: + self.logger.debug("add admin user from existing record") + + # If this is NOT existing directory user, we want to convert the record rotation setting to + # work with this gateway/controller. + # If it is a directory user, we just want link this record; no conversion. 
+ if admin_result.is_directory_user is False: + + self.logger.debug("the admin user is NOT a directory user, convert record's rotation settings") + + # This is a pamUser record that may need to have the controller set. + # Add it to this queue to make sure the protobuf items are current. + parent_record_uid = resource_content.record_uid + if resource_content.object_type_value == "providers": + parent_record_uid = None + + bulk_convert_records.append( + BulkRecordConvert( + record_uid=admin_record_uid, + parent_record_uid=parent_record_uid + ) + ) + + # If this user record does not belong to another resource, make it belong to this one. + record_vertex = self.record_link.acl_has_belong_to_record_uid(admin_record_uid) + if record_vertex is None: + admin_acl.belongs_to = True + + # There is _prepare_record, the record exists. + # Needs to add to records linking. + else: + self.logger.debug("the admin user is a directory user") + + # Link the record UIDs. + # We might not have this user in discovery data. + # It might not belong to the resource; if so, it cannot be rotated. + # It only has is_admin in the ACL. + # self.record_link.belongs_to( + # admin_record_uid, + # record_uid, + # acl=admin_acl + # ) + + return admin_record_uid + + return None + + def _get_count(self, current_vertex: DAGVertex) -> int: + + """ + Get the number of vertices that have not been converted to record. + + This will recurse down the graph. + To be counted, the current vertex being evaluated, must ... + + * not have record UID. + * not be ignored either by flag or rule. + * not be auto added. + + To recurse down, the current vertex being evaluated, must ... + + * have a record UID + * not be ignored either by flag or rule. 
+ + """ + + count = 0 + + for vertex in current_vertex.has_vertices(): + if not vertex.active: + continue + content = DiscoveryObject.get_discovery_object(vertex) + + # Add this record to the count, if no record UID, not ignoring, and we are not auto adding or + # ignoring from rules. + if (content.record_uid is None + and content.ignore_object is False + and content.action_rules_result != "add" + and content.action_rules_result != "ignore"): + count += 1 + + # Go deeper if there is a record UID, and we are not ignoring, and the rule result is not to ignore. + if ( + content.record_uid is not None + and content.ignore_object is False + and content.action_rules_result != "ignore"): + count += self._get_count(vertex) + + return count + + @property + def no_items_left(self): + return self._get_count(self.infra.get_root) == 0 + + def run(self, + prompt_func: Callable, + record_prepare_func: Callable, + smart_add: bool = False, + record_lookup_func: Optional[Callable] = None, + record_create_func: Optional[Callable] = None, + record_convert_func: Optional[Callable] = None, + prompt_confirm_add_func: Optional[Callable] = None, + prompt_admin_func: Optional[Callable] = None, + auto_add_result_func: Optional[Callable] = None, + directory_info_func: Optional[Callable] = None, + context: Optional[Any] = None, + record_cache: Optional[dict] = None, + force_quit: bool = False + ) -> BulkProcessResults: + """ + Process the discovery results. + + :param record_cache: A dictionary of record types to keys to record UID. + :param prompt_func: Function to call when the user needs to make a decision about an object. + :param smart_add: If we have resource cred, add the resource and the users. DEPRECATED + :param record_lookup_func: Function to look up a record by UID. + :param record_prepare_func: Function to call to prepare a record to be created. + :param record_create_func: Function to call to save the prepared records. 
+ :param record_convert_func: Function to convert record to use this gateway. + :param prompt_confirm_add_func: Function to call if quiting and record have been added to queue. + :param prompt_admin_func: Function to prompt user for admin. + :param auto_add_result_func: Function to call after auto adding. Provided records to bulk add. + :param directory_info_func: Function to get users of a directory from vault records. + :param context: Context passed to the prompt and add function. These could be objects that are not in the scope + of the function. + :param force_quit: Used for testing. Throw a Quit exception after processing. + :return: + """ + sync_point = self.job.sync_point + if sync_point is None: + raise Exception("The job does not have a sync point for the graph.") + + # Get the root vertex, which has nothing we care about. + # But from the root, get the configuration vertex. + # There will be only one. + self.logger.debug(f"loading the graph at sync point {sync_point}") + self.infra.load(sync_point=sync_point) + if not self.infra.has_discovery_data: + raise NoDiscoveryDataException("There is no discovery data to process.") + + # If the graph is corrupted, delete the bad vertices. + # + if self.infra.dag.is_corrupt is True: + self.logger.debug("the graph is corrupt, deleting vertex") + for uid in self.infra.dag.corrupt_uids: + vertex = self.infra.dag.get_vertex(uid) + vertex.delete() + self.infra.dag.corrupt_uids = [] + self.logger.info("fixed the corrupted vertices") + + root = self.infra.get_root + configuration = root.has_vertices()[0] + + # If we have a record cache, attempt to find vertices where the content does not have the record UID set and + # then update them with cached records from the vault. + # This is done incase someone has manually created a record after discovery has been done. 
+ if record_cache is not None: + self._update_with_record_uid( + record_cache=record_cache, + current_vertex=configuration, + ) + + # Store records that to be created and record where their protobuf settings need to be updated. + bulk_add_records = [] # type: List[BulkRecordAdd] + bulk_convert_records = [] # type: List[BulkRecordConvert] + + should_add_records = True + bulk_process_results = None + + # Pass an empty + if context is None: + context = {} + + # We need record linking and infra graphs in the context. + # We are adding admin users to check existing admin relationships and to see if AD user. + context["record_link"] = self.record_link + context["infra"] = self.infra + + try: + + self.logger.debug("# ####################################################################################") + self.logger.debug("# AUTO ADD ITEMS") + self.logger.debug("#") + self.logger.debug(f"smart add = {smart_add}") + + # Process the auto add entries first. + # There are no prompts. + self._process_auto_add_level( + current_vertex=configuration, + bulk_add_records=bulk_add_records, + bulk_convert_records=bulk_convert_records, + record_lookup_func=record_lookup_func, + record_prepare_func=record_prepare_func, + directory_info_func=directory_info_func, + record_cache=record_cache, + context=context) + + # If set, give the client a list of record that will be added. + # Can be used for displaying how many record are auto added. + if auto_add_result_func is not None: + auto_add_result_func(bulk_add_records=bulk_add_records) + + self.logger.debug("# ####################################################################################") + self.logger.debug("# PROMPT USER ITEMS") + self.logger.debug("#") + + # This is the total number of items that processing needs to process. + # We start with items_left equal to item_count. 
+ item_count = self._get_count(configuration) + + self._process_level( + current_vertex=configuration, + bulk_add_records=bulk_add_records, + bulk_convert_records=bulk_convert_records, + record_lookup_func=record_lookup_func, + prompt_func=prompt_func, + prompt_admin_func=prompt_admin_func, + record_prepare_func=record_prepare_func, + directory_info_func=directory_info_func, + record_cache=record_cache, + indent=0, + item_count=item_count, + items_left=item_count, + context=context) + + # This mainly for testing. + # If throw and quit exception, so we can prompt the user. + if force_quit: + raise QuitException() + + except QuitException: + should_add_records = False + + # If we have record ready to be created, and the confirm prompt function was set, ask the user if they want + # to add the records. + if (len(bulk_add_records) > 0 and prompt_confirm_add_func is not None and + prompt_confirm_add_func(bulk_add_records) is True): + should_add_records = True + + modified_count = len(self.infra.dag.modified_edges) + self.logger.debug(f"quiting and there are {modified_count} modified edges.") + + # If we don't have a create function, then there is no way to add record. + if record_create_func is None: + should_add_records = False + + # We should add the record, and a method was passed in to create them; then add the records. + if should_add_records: + + self.logger.debug("# ####################################################################################") + self.logger.debug("# CREATE NEW RECORD") + self.logger.debug("#") + + # Save new records. 
+ bulk_process_results = record_create_func( + bulk_add_records=bulk_add_records, + context=context + ) + self.logger.debug("# ####################################################################################") + + self.logger.debug("# ####################################################################################") + self.logger.debug("# CONVERT EXISTING RECORD") + self.logger.debug("#") + + # Update existing record to use this gateway. + record_convert_func( + bulk_convert_records=bulk_convert_records, + context=context + ) + self.logger.debug("# ####################################################################################") + else: + + self.logger.debug("# ####################################################################################") + self.logger.debug("# ROLLBACK GRAPH") + self.logger.debug("#") + + for record in bulk_add_records: + vertices = self.infra.dag.search_content({"record_uid": record.record_uid}) + for vertex in vertices: + self.logger.debug(f" * {record.title}, flagged") + vertex.skip_save = True + for record in bulk_convert_records: + vertices = self.infra.dag.search_content({"record_uid": record.record_uid}) + for vertex in vertices: + self.logger.debug(f" * {record.title}, flagged") + vertex.skip_save = True + + self.logger.debug("# ####################################################################################") + + self.logger.debug("# ####################################################################################") + self.logger.debug("# Save INFRASTRUCTURE graph") + self.logger.debug("#") + + # Disable delta save. 
+ self.logger.debug(f"saving additions from process run") + self.infra.save(delta_graph=False) + self.logger.debug("# ####################################################################################") + + # Update the user service mapping + self.user_service.run(infra=self.infra) + + return bulk_process_results diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py new file mode 100644 index 00000000..8f22afaa --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py @@ -0,0 +1,499 @@ +from __future__ import annotations +import logging +from ..keeper_dag.dag_utils import get_connection, make_agent +from ..keeper_dag.dag import DAG, EdgeType +from ..keeper_dag.dag_types import PamGraphId, PamEndpoints, UserAcl, DiscoveryObject +import importlib +from typing import Any, Optional, List, TYPE_CHECKING + +if TYPE_CHECKING: + from ..keeper_dag.dag import DAGVertex + + +class RecordLink: + + def __init__(self, + record: Any, + logger: Optional[Any] = None, + debug_level: int = 0, + fail_on_corrupt: bool = True, + log_prefix: str = "GS Record Linking", + save_batch_count: int = 200, + agent: Optional[str] = None, + use_read_protobuf: bool = False, + use_write_protobuf: bool = True, + **kwargs): + + self.conn = get_connection(logger=logger, + use_read_protobuf=use_read_protobuf, + use_write_protobuf=use_write_protobuf, + **kwargs) + + # This will either be a KSM Record, or Commander KeeperRecord + self.record = record + self._dag = None + if logger is None: + logger = logging.getLogger() + self.logger = logger + self.log_prefix = log_prefix + self.debug_level = debug_level + self.save_batch_count = save_batch_count + + # Based on the connection type, use_write_protobuf might be set to False is True was passed. + # Use self.conn.use_write_protobuf; don't use passed in use_write_protobuf. + # If using protobuf to write, then use the endpoint. 
+        self.write_endpoint = None
+        if self.conn.use_write_protobuf:
+            self.write_endpoint = PamEndpoints.PAM
+
+        self.read_endpoint = None
+        if self.conn.use_read_protobuf:
+            self.read_endpoint = PamEndpoints.PAM
+
+        self.agent = make_agent("record_linking")
+        if agent is not None:
+            self.agent += "; " + agent
+
+        # Technically, since there is no encryption in this graph, there should be no corruption.
+        # Allow it to be set regardless.
+        self.fail_on_corrupt = fail_on_corrupt
+
+    @property
+    def dag(self) -> DAG:
+        if self._dag is None:
+
+            # Make sure this auto save is False.
+            # Since we don't have transactions, we want to save the record link if everything worked.
+            self._dag = DAG(conn=self.conn,
+                            record=self.record,
+                            write_endpoint=self.write_endpoint,
+                            read_endpoint=self.read_endpoint,
+                            graph_id=PamGraphId.PAM,
+                            auto_save=False,
+                            logger=self.logger,
+                            debug_level=self.debug_level,
+                            name="Record Linking",
+                            fail_on_corrupt=self.fail_on_corrupt,
+                            log_prefix=self.log_prefix,
+                            save_batch_count=self.save_batch_count,
+                            agent=self.agent)
+            sync_point = self._dag.load(sync_point=0)
+            self.logger.debug(f"the record linking sync point is {sync_point or 0}")
+            if not self.dag.has_graph:
+                self.dag.add_vertex(name=self.record.title, uid=self._dag.uid)
+
+        return self._dag
+
+    def close(self):
+        """
+        Clean up resources held by this RecordLink instance.
+        Releases the DAG instance and connection to prevent memory leaks.
+ """ + if self._dag is not None: + self._dag = None + self.conn = None + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures cleanup.""" + self.close() + return False + + def __del__(self): + self.close() + + @property + def has_graph(self) -> bool: + return self.dag.has_graph + + def reload(self): + self._dag.load(sync_point=0) + + def get_record_link(self, uid: str) -> DAGVertex: + return self.dag.get_vertex(uid) + + def get_parent_uid(self, uid: str) -> Optional[str]: + """ + Get the vertex that the UID belongs to. + + This method will check the vertex ACL to see which edge has a True value for belongs_to. + If it is found, the record UID that the head points at will be returned. + If not found, None is returned. + """ + + vertex = self.dag.get_vertex(uid) + if vertex is not None: + for edge in vertex.edges: + if edge.edge_type == EdgeType.ACL: + content = edge.content_as_object(UserAcl) + if content.belongs_to is True: + return edge.head_uid + return None + + @staticmethod + def get_record_uid(discovery_vertex: DAGVertex, validate_record_type: Optional[str] = None) -> str: + """ + Get the record UID from the vertex + + """ + data = discovery_vertex.get_data() + if data is None: + raise Exception(f"The discovery vertex {discovery_vertex.uid} does not have a DATA edge. " + "Cannot get record UID.") + content = DiscoveryObject.get_discovery_object(discovery_vertex) + + if validate_record_type is not None: + if validate_record_type != content.record_type: + raise Exception(f"The vertex is not record type {validate_record_type}") + + if content.record_uid is not None: + return content.record_uid + raise Exception(f"The discovery vertex {discovery_vertex.uid} data does not have a populated record UID.") + + def add_configuration(self, discovery_vertex: DAGVertex): + """ + Add the configuration vertex to the DAG root. 
+ + The configuration record UID will be the same as root UID. + + """ + + record_uid = self.get_record_uid(discovery_vertex) + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is None: + record_vertex = self.dag.add_vertex(uid=record_uid, name=discovery_vertex.name) + if not self.dag.get_root.has(record_vertex): + record_vertex.belongs_to_root(EdgeType.LINK) + + def discovery_belongs_to(self, discovery_vertex: DAGVertex, discovery_parent_vertex: DAGVertex, + acl: Optional[UserAcl] = None): + + """ + Link vault record using the vertices from discovery. + + If a link already exists, no additional link will be created. + """ + + try: + record_uid = self.get_record_uid(discovery_vertex) + except Exception as err: + self.logger.warning(f"The discovery vertex is missing a record uid, cannot connect record: {err}") + return + + # If the parent_vertex is the root, then don't get the record UID from the data. + # The root vertex will have no data, and the record UID is the same as the vertex UID. + if discovery_parent_vertex.uid == self.dag.uid: + parent_record_uid = discovery_parent_vertex.uid + else: + try: + parent_record_uid = self.get_record_uid(discovery_parent_vertex) + except Exception as err: + self.logger.warning("The discovery parent vertex is missing a record uid, cannot connect record: " + f"{err}") + return + + self.belongs_to( + record_uid=record_uid, + parent_record_uid=parent_record_uid, + acl=acl, + record_name=discovery_vertex.name, + parent_record_name=discovery_parent_vertex.name + ) + + def belongs_to(self, record_uid: str, parent_record_uid: str, acl: Optional[UserAcl] = None, + record_name: Optional[str] = None, parent_record_name: Optional[str] = None): + + """ + Link vault records using record UIDs. + + If a link already exists, no additional link will be created. + """ + + # Get the record's vertices. 
+ # If a vertex does not exist, then add the vertex using the record UID + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is None: + self.logger.debug(f"adding record linking vertex for record UID {record_uid} ({record_name})") + record_vertex = self.dag.add_vertex(uid=record_uid, name=record_name) + + parent_record_vertex = self.dag.get_vertex(parent_record_uid) + if parent_record_vertex is None: + self.logger.debug(f"adding record linking vertex for parent record UID {parent_record_uid}") + parent_record_vertex = self.dag.add_vertex(uid=parent_record_uid, name=parent_record_name) + + self.logger.debug(f"record UID {record_vertex.uid} belongs to {parent_record_vertex.uid} " + f"({parent_record_name})") + + # By default, the LINK edge will link records. + # If ACL information was passed in, use the ACL edge. + edge_type = EdgeType.LINK + if acl is not None: + edge_type = EdgeType.ACL + + # Get the current edge if it exists. + # We need to create it if it does not exist and only add it if the ACL changed. + existing_edge = record_vertex.get_edge(parent_record_vertex, edge_type=edge_type) + add_edge = True + if existing_edge is not None and existing_edge.active is True: + if edge_type == EdgeType.ACL: + content = existing_edge.content_as_object(UserAcl) # type: UserAcl + if content.model_dump_json() == acl.model_dump_json(): + add_edge = False + else: + add_edge = False + + if add_edge: + self.logger.debug(f" added {edge_type} edge") + record_vertex.belongs_to(parent_record_vertex, edge_type=edge_type, content=acl) + + def get_acl(self, record_uid: str, parent_record_uid: str, record_name: Optional[str] = None, + parent_record_name: Optional[str] = None) -> Optional[UserAcl]: + + # Get the record's vertices. 
+        # If a vertex does not exist, then add the vertex using the record UID
+        record_vertex = self.dag.get_vertex(record_uid)
+        if record_vertex is None:
+            self.logger.debug(f"adding record linking vertex for record UID {record_uid} ({record_name})")
+            record_vertex = self.dag.add_vertex(uid=record_uid, name=record_name)
+
+        parent_record_vertex = self.dag.get_vertex(parent_record_uid)
+        if parent_record_vertex is None:
+            self.logger.debug(f"adding record linking vertex for parent record UID {parent_record_uid}")
+            parent_record_vertex = self.dag.add_vertex(uid=parent_record_uid, name=parent_record_name)
+
+        acl_edge = record_vertex.get_edge(parent_record_vertex, edge_type=EdgeType.ACL)
+        if acl_edge is None:
+            return None
+
+        return acl_edge.content_as_object(UserAcl)
+
+    def acl_has_belong_to_vertex(self, discovery_vertex: DAGVertex) -> Optional[DAGVertex]:
+        """
+        Get the resource vertex for this user vertex that handles rotation, using the user's infrastructure vertex.
+        """
+
+        record_uid = self.get_record_uid(discovery_vertex, "pamUser")
+        if record_uid is None:
+            return None
+
+        return self.acl_has_belong_to_record_uid(record_uid)
+
+    def acl_has_belong_to_record_uid(self, record_uid: str) -> Optional[DAGVertex]:
+
+        """
+        Get the resource vertex for this user vertex that handles rotation, using the user's record UID.
+        """
+
+        record_vertex = self.dag.get_vertex(record_uid)
+        if record_vertex is None:
+            return None
+        for edge in record_vertex.edges:
+            if edge.edge_type != EdgeType.ACL:
+                continue
+            content = edge.content_as_object(UserAcl)
+            if content.belongs_to is True:
+                return self.dag.get_vertex(edge.head_uid)
+        return None
+
+    def get_parent_record_uid(self, record_uid: str) -> Optional[str]:
+        """
+        Get the parent record uid.
+
+        Check the ACL edges for the one where belongs_to is True.
+        Otherwise, follow the LINK edge that leads to the parent.
+ """ + + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is None: + return None + for edge in record_vertex.edges: + if edge.edge_type == EdgeType.ACL: + content = edge.content_as_object(UserAcl) # type: UserAcl + if content.belongs_to: + return edge.head_uid + elif edge.edge_type == EdgeType.LINK: + return edge.head_uid + return None + + def get_child_record_uids(self, record_uid: str) -> List[str]: + """ + Get a list of child record for this parent. + + The list contains any parent that this record uid has a LINK or ACL edge to. + """ + + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is None: + self.logger.debug(f"could not get the parent record for {record_uid}") + return [] + + record_uids = [] + self.logger.debug(f"has {record_vertex.has_vertices()}") + for child_vertex in record_vertex.has_vertices(EdgeType.ACL): + record_uids.append(child_vertex.uid) + for child_vertex in record_vertex.has_vertices(EdgeType.LINK): + record_uids.append(child_vertex.uid) + + return record_uids + + def get_parent_record_uids(self, record_uid: str) -> List[str]: + """ + Get a list of parent record this child record belongs to. + + The list contains any parent that this record uid has a LINK or ACL edge to. + """ + + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is None: + self.logger.debug(f"could not get the child record for {record_uid}") + return [] + + record_uids = [] + for vertex in record_vertex.belongs_to_vertices(): + edge = vertex.get_edge(record_vertex, EdgeType.ACL) + if edge is None: + edge = vertex.get_edge(record_vertex, EdgeType.LINK) + if edge is not None: + record_uids.append(record_vertex.uid) + return record_uids + + def get_admin_record_uid(self, record_uid: str) -> Optional[str]: + """ + Get the record that admins this resource record. 
+ + """ + + record_vertex = self.dag.get_vertex(record_uid) + if record_vertex is not None: + for vertex in record_vertex.has_vertices(): + for edge in vertex.edges: + if edge.head_uid != record_vertex.uid: + continue + if edge.edge_type == EdgeType.ACL: + content = edge.content_as_object(UserAcl) # type: UserAcl + if content.is_admin is True: + return vertex.uid + return None + + def discovery_disconnect_from(self, discovery_vertex: DAGVertex, discovery_parent_vertex: DAGVertex): + record_uid = self.get_record_uid(discovery_vertex) + parent_record_uid = self.get_record_uid(discovery_parent_vertex) + self.disconnect_from(record_uid=record_uid, parent_record_uid=parent_record_uid) + + def disconnect_from(self, record_uid: str, parent_record_uid: str): + record_vertex = self.dag.get_vertex(record_uid) + parent_record_vertex = self.dag.get_vertex(parent_record_uid) + + # Check if we got vertex for the record UIDs. + # Log info if we didn't. + # Since we are disconnecting, we are not going to treat this as a fatal error. + if record_vertex is None: + self.logger.info(f"for record linking, could not find the vertex for record UID {record_uid}." + f" cannot disconnect from parent vertex for record UID {parent_record_uid}") + return + if parent_record_vertex is None: + self.logger.info(f"for record linking, could not find the parent vertex for record UID {parent_record_uid}." 
+ f" cannot disconnect the child vertex for record UID {record_uid}") + return + + parent_record_vertex.disconnect_from(record_vertex) + + @staticmethod + def delete(vertex: DAGVertex): + if vertex is not None: + vertex.delete() + + def save(self): + + self.logger.debug("DISCOVERY COMMON RECORD LINKING GRAPH SAVE CALLED") + if self.dag.has_graph: + self.logger.debug("saving the record linking.") + self.dag.save(delta_graph=False) + else: + self.logger.debug("the record linking graph does not contain any data, was not saved.") + + def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only_active_vertices: bool = True, + show_only_active_edges: bool = True, graph_type: str = "dot"): + + try: + mod = importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"DAG for Record Linking", format=graph_format) + + if graph_type == "dot": + dot.attr(rankdir='RL') + elif graph_type == "twopi": + dot.attr(layout="twopi") + dot.attr(ranksep="10") + dot.attr(ratio="auto") + else: + dot.attr(layout=graph_type) + + self.logger.debug(f"have {len(self.dag.all_vertices)} vertices") + for v in self.dag.all_vertices: + if show_only_active_vertices is True and v.active is False: + continue + + tooltip = "" + + for edge in v.edges: + + color = "grey" + style = "solid" + + # To reduce the number of edges, only show the active edges + if edge.active is True: + color = "black" + style = "bold" + elif show_only_active_edges is True: + continue + + # If the vertex is not active, gray out the DATA edge + if edge.edge_type == EdgeType.DATA and v.active is False: + color = "grey" + + if edge.edge_type == EdgeType.DELETION: + style = "dotted" + + edge_tip = "" + if edge.edge_type == EdgeType.ACL and v.active is True: + content = edge.content_as_dict + if content.get("is_admin") is True: + color = "red" + if content.get("belongs_to") is True: + if color == "red": 
+ color = "purple" + else: + color = "blue" + + tooltip += f"TO {edge.head_uid}\\n" + for k, val in content.items(): + tooltip += f" * {k}={val}\\n" + tooltip += f"--------------------\\n\\n" + + label = DAG.EDGE_LABEL.get(edge.edge_type) + if label is None: + label = "UNK" + if edge.path is not None and edge.path != "": + label += f"\\npath={edge.path}" + if show_version is True: + label += f"\\nv={edge.version}" + + # tail, head (arrow side), label, ... + dot.edge(v.uid, edge.head_uid, label, style=style, fontcolor=color, color=color, tooltip=edge_tip) + + shape = "ellipse" + fillcolor = "white" + color = "black" + if v.active is False: + fillcolor = "grey" + + label = f"uid={v.uid}" + dot.node(v.uid, label, color=color, fillcolor=fillcolor, style="filled", shape=shape, tooltip=tooltip) + + return dot diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/rule.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/rule.py new file mode 100644 index 00000000..c450f12d --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/rule.py @@ -0,0 +1,383 @@ +from __future__ import annotations +from .dag_types import (RuleTypeEnum, RuleItem, ActionRuleSet, ActionRuleItem, ScheduleRuleSet, ComplexityRuleSet, + Statement, RuleActionEnum) +from .dag_utils import value_to_boolean, get_connection, make_agent +from .dag import DAG, EdgeType +from .exceptions import DAGException +from .dag_types import PamGraphId +from time import time +import base64 +import os +from typing import Any, List, Optional, Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from .dag_types import DiscoveryObject + + +class Rules: + + DATA_PATH = "rules" + RULE_ITEM_TYPE_MAP = { + "ActionRuleItem": RuleTypeEnum.ACTION, + "ScheduleRuleItem": RuleTypeEnum.SCHEDULE, + "ComplexityRuleItem": RuleTypeEnum.COMPLEXITY + } + RULE_TYPE_TO_SET_MAP = { + RuleTypeEnum.ACTION: ActionRuleSet, + RuleTypeEnum.SCHEDULE: ScheduleRuleSet, + RuleTypeEnum.COMPLEXITY: ComplexityRuleSet + } + + RULE_FIELDS 
= { + # Attributes the records + "recordType": {"type": str}, + "parentRecordType": {"type": str}, + "recordTitle": {"type": str}, + "recordNotes": {"type": str}, + "recordDesc": {"type": str}, + "parentUid": {"type": str}, + + # Record fields + "login": {"type": str}, + "password": {"type": str}, + "privatePEMKey": {"type": str}, + "distinguishedName": {"type": str}, + "connectDatabase": {"type": str}, + "managed": {"type": bool, "default": False}, + "hostName": {"type": str}, + "port": {"type": float, "default": 0}, + "operatingSystem": {"type": str}, + "instanceName": {"type": str}, + "instanceId": {"type": str}, + "providerGroup": {"type": str}, + "providerRegion": {"type": str}, + "databaseId": {"type": str}, + "databaseType": {"type": str}, + "useSSL": {"type": bool, "default": False}, + "domainName": {"type": str}, + "directoryId": {"type": str}, + "directoryType": {"type": str}, + + # Progmatically added + "ip": {"type": str}, + } + + BREAK_OUT = { + "pamHostname": { + "hostName": "hostName", + "port": "port" + } + } + + # If creating an ignore role, these fields are used in the rule. + RECORD_FIELD = { + "pamMachine": ["pamHostname"], + "pamDatabase": ["pamHostname", "databaseType"], + "pamDirectory": ["pamHostname", "directoryType"], + "pamUser": ["parentUid", "login", "distinguishedName"], + } + + OBJ_ATTR = { + "parentUid": "parent_record_uid" + } + + def __init__(self, record: Any, logger: Optional[Any] = None, debug_level: int = 0, fail_on_corrupt: bool = True, + agent: Optional[str] = None, **kwargs): + + self.conn = get_connection(**kwargs) + + # This will either be a KSM Record, or Commander KeeperRecord + self.record = record + self._dag = None + self.logger = logger + self.debug_level = debug_level + self.fail_on_corrupt = fail_on_corrupt + + self.agent = make_agent("rules") + if agent is not None: + self.agent += "; " + agent + + @property + def dag(self) -> DAG: + if self._dag is None: + + # Turn auto_save on after the DAG has been created. 
+ # No need to call it six times in a row to initialize it. + self._dag = DAG(conn=self.conn, + record=self.record, + # endpoint=PamEndpoints.DISCOVERY_RULES, + graph_id=PamGraphId.DISCOVERY_RULES, + auto_save=False, + logger=self.logger, + debug_level=self.debug_level, + fail_on_corrupt=self.fail_on_corrupt, + agent=self.agent) + self._dag.load() + + # Has the status been initialized? + if not self._dag.has_graph: + for rule_type_enum in Rules.RULE_TYPE_TO_SET_MAP: + rules = self._dag.add_vertex() + rules.belongs_to_root( + EdgeType.KEY, + path=rule_type_enum.value + ) + content = Rules.RULE_TYPE_TO_SET_MAP[rule_type_enum]() + rules.add_data( + content=content, + ) + self._dag.save() + + # The graph exists now, turn on the auto_save. + self._dag.auto_save = True + return self._dag + + def close(self): + """ + Clean up resources held by this Rules instance. + Releases the DAG instance and connection to prevent memory leaks. + """ + + try: + if hasattr(self, "_dag"): + self.conn = None + del self._dag + if hasattr(self, "conn"): + self.conn = None + del self.conn + if hasattr(self, "record"): + self.conn = None + del self.conn + except (Exception,): + pass + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures cleanup.""" + self.close() + return False + + def __del__(self): + self.close() + + + @staticmethod + def data_path(rule_type: RuleTypeEnum): + return f"/{rule_type.value}" + + def get_ruleset(self, rule_type: RuleTypeEnum): + path = self.data_path(rule_type) + rule_json = self.dag.walk_down_path(path).content_as_str + if rule_json is None: + raise DAGException("Could not get the status data from the DAG.") + rule_set_class = Rules.RULE_TYPE_TO_SET_MAP[rule_type] + return rule_set_class.model_validate_json(rule_json) + + def set_ruleset(self, rule_type: RuleTypeEnum, rules: List[Rules]): + path = self.data_path(rule_type) + 
self.dag.walk_down_path(path).add_data( + content=rules, + ) + # Auto save should save the data + + def _rule_transaction(self, func: Callable, rule: Optional[RuleItem] = None): + rule_type = rule.__class__.__name__ + rule_type_enum = Rules.RULE_ITEM_TYPE_MAP.get(rule_type) + if rule_type_enum is None: + raise ValueError("rule is not a known rule instance") + + # Get the ruleset and the rule list for the type + ruleset = self.get_ruleset(rule_type_enum) + + # Call the specialized code + rules = func( + r=rule, + rs=ruleset.rules + ) + + # Sort the rule by priority in asc order. + ruleset.rules = list(sorted(rules, key=lambda x: x.priority)) + self.set_ruleset(rule_type_enum, ruleset) + + def add_rule(self, rule: RuleItem) -> RuleItem: + + if rule.rule_id is None: + rule.rule_id = "RULE" + base64.urlsafe_b64encode(os.urandom(8)).decode().rstrip('=') + if rule.added_ts is None: + rule.added_ts = int(time()) + + def _add_rule(r: RuleItem, rs: List[RuleItem]): + rs.append(r) + return rs + + self._rule_transaction( + rule=rule, + func=_add_rule + ) + + return rule + + def update_rule(self, rule: RuleItem) -> RuleItem: + + def _update_rule(r: RuleItem, rs: List[RuleItem]): + new_rule_list = [] + for _r in rs: + if _r.rule_id == r.rule_id: + new_rule_list.append(r) + else: + new_rule_list.append(_r) + return new_rule_list + + self._rule_transaction( + rule=rule, + func=_update_rule + ) + + return rule + + def remove_rule(self, rule: RuleItem): + + def _remove_rule(r: RuleItem, rs: List[RuleItem]): + new_rule_list = [] + for _r in rs: + if _r.rule_id != r.rule_id: + new_rule_list.append(_r) + return new_rule_list + + self._rule_transaction( + rule=rule, + func=_remove_rule + ) + + def remove_all(self, rule_type: RuleTypeEnum): + + def _remove_all_rules(r: Any, rs: List[RuleItem]): + return [] + + # _rule_transaction determines the graph vertex from Rule class type + fake_rule = None + if rule_type == RuleTypeEnum.ACTION: + fake_rule = ActionRuleItem(statement=[]) + else: + 
raise ValueError("rule type not supported with remove_all") + + self._rule_transaction( + rule=fake_rule, + func=_remove_all_rules + ) + + def rule_list(self, rule_type: RuleTypeEnum, search: Optional[str] = None) -> List[RuleItem]: + rule_list = [] + for rule_item in self.get_ruleset(rule_type).rules: + if search is not None and rule_item.search(search) is False: + continue + rule_list.append(rule_item) + + return rule_list + + def get_rule_item(self, rule_type: RuleTypeEnum, rule_id: str) -> Optional[RuleItem]: + for rule_item in self.rule_list(rule_type=rule_type): + if rule_item.rule_id == rule_id: + return rule_item + return None + + @staticmethod + def make_action_rule_from_content(content: DiscoveryObject, action: RuleActionEnum, priority: Optional[int] = None, + case_sensitive: bool = True, + shared_folder_uid: Optional[str] = None) -> ActionRuleItem: + + if action == RuleActionEnum.IGNORE: + priority = -1 + + record_fields = Rules.RECORD_FIELD.get(content.record_type) + if record_fields is None: + raise ValueError(f"Record type {content.record_type} does not have fields maps.") + + statements = [ + Statement(field="recordType", operator="==", value=content.record_type) + ] + + for field_label in record_fields: + if field_label in Rules.OBJ_ATTR: + attr = Rules.OBJ_ATTR[field_label] + if not hasattr(content, attr): + raise Exception(f"Discovery object is missing attribute {attr}") + value = getattr(content, attr) + statements.append( + Statement(field=field_label, operator="==", value=value) + ) + else: + for field in content.fields: + label = field.label + if field_label != label: + continue + + value = field.value + if value is None or len(value) == 0: + continue + value = value[0] + + if label in Rules.BREAK_OUT: + for key in Rules.BREAK_OUT[label]: + key_value = value.get(key) + if key_value is None: + continue + statements.append( + Statement(field=key, operator="==", value=key_value) + ) + else: + statements.append( + Statement(field=label, 
operator="==", value=value) + ) + + return ActionRuleItem( + enabled=True, + priority=priority, + case_sensitive=case_sensitive, + statement=statements, + action=action, + shared_folder_uid=shared_folder_uid + ) + + @staticmethod + def make_action_rule_statement_str(statement: List[Statement]) -> str: + statement_str = "" + for item in statement: + if statement_str != "": + statement_str += " and " + statement_str += item.field + " " + item.operator + " " + field_type = Rules.RULE_FIELDS.get(item.field).get("type") + if field_type is None: + raise ValueError("Unknown field in rule") + + values = item.value + new_values = [] + if item.operator != "in": + values = [values] + + for value in values: + if field_type is str: + new_value = f"'{value}'" + elif field_type is bool: + if value_to_boolean(value) is True: + new_value = "true" + else: + new_value = "false" + elif field_type is float: + if int(value) == value: + new_value = str(int(value)) + else: + new_value = str(value) + else: + raise ValueError("Cannot determine the field type for rule statement.") + + new_values.append(new_value) + + if item.operator == "in": + statement_str += "[" + ", ".join(new_values) + "]" + else: + statement_str += new_values[0] + return statement_str diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/__init__.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/__init__.py new file mode 100644 index 00000000..b2b36405 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/__init__.py @@ -0,0 +1,56 @@ +from __future__ import annotations +import logging +from ..dag_types import SyncQuery, Ref, RefType, DAGData, DataPayload, EdgeType +from ....proto import GraphSync_pb2 as gs_pb2 +from pydantic import BaseModel +from typing import Optional, Union, List, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + Logger = Union[logging.RootLogger, logging.Logger] + + +class SyncResult(BaseModel): + sync_point: int = 0 + data: List[DAGData] 
= [] + has_more: bool = False + + +class DataStructBase: + + def __init__(self, + logger: Optional[Logger] = None): + + if logger is None: + logger = logging.getLogger() + self.logger = logger + + def sync_query(self, + stream_id: str, + sync_point: int = 0, + graph_id: Optional[int] = None) -> Union[SyncQuery, gs_pb2.GraphSyncQuery]: + pass + + @staticmethod + def origin_ref(origin_uid: str, + name: str) -> Union[Ref, gs_pb2.GraphSyncRef]: + pass + + def data(self, + data_type: EdgeType, + tail_uid: str, + content: Optional[bytes] = None, + head_uid: Optional[str] = None, + tail_name: Optional[str] = None, + head_name: Optional[str] = None, + tail_ref_type: Optional[RefType] = None, + head_ref_type: Optional[RefType] = None, + path: Optional[str] = None) -> Union[DAGData,gs_pb2.GraphSyncData]: + + pass + + @staticmethod + def payload(origin_ref: Union[Ref, gs_pb2.GraphSyncRef], + data_list: List[Union[DAGData, gs_pb2.GraphSyncData]], + graph_id: Optional[int] = None) -> Union[DataPayload, gs_pb2.GraphSyncAddDataRequest]: + + pass diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/default.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/default.py new file mode 100644 index 00000000..6f82cf25 --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/default.py @@ -0,0 +1,82 @@ +from __future__ import annotations +from . import DataStructBase +from ..dag_types import SyncQuery, Ref, RefType, DAGData, DataPayload, EdgeType, SyncData +from ..dag_crypto import generate_random_bytes, generate_uid_str +from .... 
class DataStruct(DataStructBase):
    """Builds the JSON (pydantic) flavor of the graph-sync wire structures."""

    def sync_query(self,
                   stream_id: str,
                   sync_point: int = 0,
                   graph_id: Optional[int] = None) -> SyncQuery:
        """Build a sync query with a freshly generated random device id."""
        device_id = base64.urlsafe_b64encode(generate_random_bytes(16)).decode()
        return SyncQuery(
            streamId=stream_id,
            deviceId=device_id,
            syncPoint=sync_point,
            graphId=graph_id
        )

    @staticmethod
    def get_sync_result(results: bytes) -> SyncData:
        """Parse a JSON sync response into a SyncData model."""
        return SyncData.model_validate_json(results)

    @staticmethod
    def origin_ref(origin_ref_value: bytes,
                   name: str) -> Ref:
        """Build a DEVICE-typed origin reference from raw uid bytes."""
        return Ref(
            type=RefType.DEVICE,
            value=generate_uid_str(uid_bytes=origin_ref_value),
            name=name
        )

    def data(self,
             data_type: EdgeType,
             tail_uid: str,
             content: Optional[bytes] = None,
             head_uid: Optional[str] = None,
             tail_name: Optional[str] = None,
             head_name: Optional[str] = None,
             tail_ref_type: Optional[RefType] = None,
             head_ref_type: Optional[RefType] = None,
             path: Optional[str] = None) -> DAGData:
        """Build a data edge; content is base64-url encoded for the JSON wire."""
        encoded_content = utils.base64_url_encode(content) if content is not None else None

        # The tail points at this vertex, so it uses this vertex's uid.
        tail_ref = Ref(type=tail_ref_type, value=tail_uid, name=tail_name)

        # The head (arrowhead) points at the parent vertex. For DATA edges the
        # parentRef is allowed to be None; sending it is harmless.
        head_ref = Ref(type=head_ref_type, value=head_uid, name=head_name)

        return DAGData(
            type=data_type,
            content=encoded_content,
            ref=tail_ref,
            parentRef=head_ref,
            path=path
        )

    @staticmethod
    def payload(origin_ref: Ref,
                data_list: List[DAGData],
                graph_id: Optional[int] = None) -> DataPayload:
        """Wrap an origin ref and data list into an add-data payload."""
        return DataPayload(
            origin=origin_ref,
            dataList=data_list,
            graphId=graph_id
        )
class DataStruct(DataStructBase):
    """Builds the protobuf flavor of the graph-sync wire structures.

    The mapping tables translate between the SDK enums and the protobuf
    enums defined in GraphSync.proto (keeperapp-protobuf repository).
    """

    REF_TO_PB_MAP = {
        RefType.GENERAL: gs_pb2.RefType.RFT_GENERAL,
        RefType.USER: gs_pb2.RefType.RFT_USER,
        RefType.DEVICE: gs_pb2.RefType.RFT_DEVICE,
        RefType.REC: gs_pb2.RefType.RFT_REC,
        RefType.FOLDER: gs_pb2.RefType.RFT_FOLDER,
        RefType.TEAM: gs_pb2.RefType.RFT_TEAM,
        RefType.ENTERPRISE: gs_pb2.RefType.RFT_ENTERPRISE,
        RefType.PAM_DIRECTORY: gs_pb2.RefType.RFT_PAM_DIRECTORY,
        RefType.PAM_MACHINE: gs_pb2.RefType.RFT_PAM_MACHINE,
        RefType.PAM_DATABASE: gs_pb2.RefType.RFT_PAM_DATABASE,
        RefType.PAM_USER: gs_pb2.RefType.RFT_PAM_USER,
        RefType.PAM_NETWORK: gs_pb2.RefType.RFT_PAM_NETWORK,
        RefType.PAM_BROWSER: gs_pb2.RefType.RFT_PAM_BROWSER,
        RefType.CONNECTION: gs_pb2.RefType.RFT_CONNECTION,
        RefType.WORKFLOW: gs_pb2.RefType.RFT_WORKFLOW,
        RefType.NOTIFICATION: gs_pb2.RefType.RFT_NOTIFICATION,
        RefType.USER_INFO: gs_pb2.RefType.RFT_USER_INFO,
        RefType.TEAM_INFO: gs_pb2.RefType.RFT_TEAM_INFO,
        RefType.ROLE: gs_pb2.RefType.RFT_ROLE
    }

    DATA_TO_PB_MAP = {
        EdgeType.DATA: gs_pb2.GraphSyncDataType.GSE_DATA,
        EdgeType.KEY: gs_pb2.GraphSyncDataType.GSE_KEY,
        EdgeType.LINK: gs_pb2.GraphSyncDataType.GSE_LINK,
        EdgeType.ACL: gs_pb2.GraphSyncDataType.GSE_ACL,
        EdgeType.DELETION: gs_pb2.GraphSyncDataType.GSE_DELETION
    }

    # Reverse lookups for decoding protobuf responses.
    PB_TO_REF_MAP = {v: k for k, v in REF_TO_PB_MAP.items()}
    PB_TO_DATA_MAP = {v: k for k, v in DATA_TO_PB_MAP.items()}

    def sync_query(self,
                   stream_id: str,
                   sync_point: int = 0,
                   graph_id: Optional[int] = None) -> gs_pb2.GraphSyncQuery:
        """Build a protobuf sync query.

        NOTE(review): graph_id is accepted for interface parity with the JSON
        builder but is not placed into the protobuf message — confirm that is
        intentional.
        """
        return gs_pb2.GraphSyncQuery(
            streamId=dag_crypto.urlsafe_str_to_bytes(stream_id),
            origin=dag_crypto.generate_random_bytes(16),
            syncPoint=sync_point,

            # Use the default from KRouter; currently 500
            maxCount=0
        )

    @staticmethod
    def get_sync_result(results: bytes) -> SyncData:
        """Parse a serialized GraphSyncResult into the SDK's SyncData model.

        Raises:
            Exception: if the bytes cannot be parsed as a GraphSyncResult.
        """
        # BUG FIX: parse once and reuse. The original parsed the same bytes
        # twice (once in the try block, then again into a second message).
        message = gs_pb2.GraphSyncResult()
        try:
            message.ParseFromString(results)
        except Exception as err:
            raise Exception(f"Could not parse the GraphSyncResult message: {err}")

        data_list: List[SyncDataItem] = []
        for item in message.data:
            data_list.append(
                SyncDataItem(
                    type=DataStruct.PB_TO_DATA_MAP.get(item.data.type),
                    # Protobuf content is raw bytes, not base64.
                    content=item.data.content,
                    content_is_base64=False,
                    ref=Ref(
                        type=DataStruct.PB_TO_REF_MAP.get(item.data.ref.type),
                        value=dag_crypto.bytes_to_urlsafe_str(item.data.ref.value),
                    ),
                    parentRef=Ref(
                        type=DataStruct.PB_TO_REF_MAP.get(item.data.parentRef.type),
                        value=dag_crypto.bytes_to_urlsafe_str(item.data.parentRef.value)
                    ),
                    path=item.data.path
                )
            )

        return SyncData(
            syncPoint=message.syncPoint,
            data=data_list,
            hasMore=message.hasMore
        )

    @staticmethod
    def origin_ref(origin_ref_value: bytes,
                   name: str) -> gs_pb2.GraphSyncRef:
        """Build a DEVICE-typed origin reference from raw uid bytes."""
        return gs_pb2.GraphSyncRef(
            type=gs_pb2.RefType.RFT_DEVICE,
            value=origin_ref_value,
            name=name
        )

    def data(self,
             data_type: EdgeType,
             tail_uid: str,
             content: Optional[bytes] = None,
             head_uid: Optional[str] = None,
             tail_name: Optional[str] = None,
             head_name: Optional[str] = None,
             tail_ref_type: Optional[RefType] = None,
             head_ref_type: Optional[RefType] = None,
             path: Optional[str] = None) -> gs_pb2.GraphSyncData:
        """Build a protobuf data edge; string uids are converted to bytes."""
        if isinstance(tail_uid, str):
            tail_uid = dag_crypto.urlsafe_str_to_bytes(tail_uid)
        if head_uid is not None and isinstance(head_uid, str):
            head_uid = dag_crypto.urlsafe_str_to_bytes(head_uid)

        return gs_pb2.GraphSyncData(
            type=DataStruct.DATA_TO_PB_MAP.get(data_type),
            content=content,
            # tail point at this vertex, so it uses this vertex's uid.
            ref=gs_pb2.GraphSyncRef(
                type=DataStruct.REF_TO_PB_MAP.get(tail_ref_type),
                value=tail_uid,
                name=tail_name
            ),
            # Head, the arrowhead, points at the vertex this vertex belongs to, the parent.
            # Apparently, for DATA edges, the parentRef is allowed to be None.
            # Doesn't hurt to send it.
            parentRef=gs_pb2.GraphSyncRef(
                type=DataStruct.REF_TO_PB_MAP.get(head_ref_type),
                value=head_uid,
                name=head_name
            ),
            path=path
        )

    @staticmethod
    def payload(origin_ref: gs_pb2.GraphSyncRef,
                data_list: List[gs_pb2.GraphSyncData],
                graph_id: Optional[int] = None) -> gs_pb2.GraphSyncAddDataRequest:
        """Wrap an origin ref and data list into an add-data request.

        NOTE(review): graph_id is accepted for interface parity but is not
        placed into the protobuf message — confirm that is intentional.
        """
        return gs_pb2.GraphSyncAddDataRequest(
            origin=origin_ref,
            data=data_list)
class UserService:
    """Maintains the SERVICE_LINKS graph that maps PAM user records to the
    machine records where they run services, scheduled tasks, or IIS pools.
    """

    def __init__(self, record: Any, logger: Optional[Any] = None, history_level: int = 0,
                 debug_level: int = 0, fail_on_corrupt: bool = True, log_prefix: str = "GS Services/Tasks",
                 save_batch_count: int = 200, agent: Optional[str] = None,
                 **kwargs):

        self.conn = get_connection(**kwargs)

        # This will either be a KSM Record, or Commander KeeperRecord
        self.record = record
        # The DAG is built lazily by the `dag` property.
        self._dag = None
        if logger is None:
            logger = logging.getLogger()
        self.logger = logger
        self.log_prefix = log_prefix
        self.history_level = history_level
        self.debug_level = debug_level
        self.fail_on_corrupt = fail_on_corrupt
        self.save_batch_count = save_batch_count

        self.agent = make_agent("user_service")
        if agent is not None:
            self.agent += "; " + agent

        self.auto_save = False
        self.last_sync_point = -1

    @property
    def dag(self) -> DAG:
        """Lazily construct the SERVICE_LINKS DAG and load it from sync point 0."""
        if self._dag is None:
            self._dag = DAG(conn=self.conn,
                            record=self.record,
                            graph_id=PamGraphId.SERVICE_LINKS,
                            auto_save=False,
                            logger=self.logger,
                            history_level=self.history_level,
                            debug_level=self.debug_level,
                            name="Discovery Service/Tasks",
                            fail_on_corrupt=self.fail_on_corrupt,
                            log_prefix=self.log_prefix,
                            save_batch_count=self.save_batch_count,
                            agent=self.agent)

            self._dag.load(sync_point=0)

        return self._dag

    def close(self):
        """
        Clean up resources held by this UserService instance.
        Releases the DAG instance and connection to prevent memory leaks.
        """
        if self._dag is not None:
            self._dag = None
        self.conn = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures cleanup."""
        self.close()
        return False

    def __del__(self):
        # BUG FIX: __del__ may run on a partially constructed instance (e.g.
        # when __init__ raised before setting _dag/conn); an exception must
        # never escape the finalizer.
        try:
            self.close()
        except Exception:
            pass

    @property
    def has_graph(self) -> bool:
        """True when the loaded graph contains any data."""
        return self.dag.has_graph

    def reload(self):
        """Reload the graph from sync point 0.

        BUG FIX: the original dereferenced ``self._dag`` directly, which is
        None (AttributeError) when the graph was never accessed.
        """
        if self._dag is None:
            _ = self.dag  # building the DAG performs the initial load
        else:
            self._dag.load(sync_point=0)

    def get_record_link(self, uid: str) -> DAGVertex:
        """Return the vertex for a record UID (None when not in the graph)."""
        return self.dag.get_vertex(uid)

    @staticmethod
    def get_record_uid(discovery_vertex: DAGVertex) -> str:
        """
        Get the record UID from the vertex

        Raises:
            Exception: if the vertex has no DATA edge or its discovery
                content has no record UID.
        """
        data = discovery_vertex.get_data()
        if data is None:
            raise Exception(f"The discovery vertex {discovery_vertex.uid} does not have a DATA edge. "
                            "Cannot get record UID.")
        content = DiscoveryObject.get_discovery_object(discovery_vertex)
        if content.record_uid is not None:
            return content.record_uid
        raise Exception(f"The discovery vertex {discovery_vertex.uid} data does not have a populated record UID.")
+ # If a vertex does not exist, then add the vertex using the record UID + resource_vertex = self.dag.get_vertex(resource_uid) + if resource_vertex is None: + self.logger.debug(f"adding resource vertex for record UID {resource_uid} ({resource_name})") + resource_vertex = self.dag.add_vertex(uid=resource_uid, name=resource_name) + + user_vertex = self.dag.get_vertex(user_uid) + if user_vertex is None: + self.logger.debug(f"adding user vertex for record UID {user_uid} ({user_name})") + user_vertex = self.dag.add_vertex(uid=user_uid, name=user_name) + + self.logger.debug(f"user {user_vertex.uid} controls services on {resource_vertex.uid}") + + edge_type = EdgeType.LINK + if acl is not None: + edge_type = EdgeType.ACL + + user_vertex.belongs_to(resource_vertex, edge_type=edge_type, content=acl) + + def disconnect_from(self, resource_uid: str, user_uid: str): + resource_vertex = self.dag.get_vertex(resource_uid) + user_vertex = self.dag.get_vertex(user_uid) + user_vertex.disconnect_from(resource_vertex) + + def get_acl(self, resource_uid, user_uid) -> Optional[ServiceAcl]: + + """ + Get the service/task ACL between a resource and the user. + + """ + + resource_vertex = self.dag.get_vertex(resource_uid) + user_vertex = self.dag.get_vertex(user_uid) + if resource_vertex is None or user_vertex is None: + self.logger.debug(f"there is no acl between {resource_uid} and {user_uid}") + return ServiceAcl() + + acl_edge = user_vertex.get_edge(resource_vertex, edge_type=EdgeType.ACL) # type: DAGEdge + if acl_edge is None: + return None + + return acl_edge.content_as_object(ServiceAcl) + + def resource_has_link(self, resource_uid) -> bool: + """ + Is this resource linked to the configuration? 
+ """ + + resource_vertex = self.dag.get_vertex(resource_uid) + if resource_vertex is None: + return False + link_edge = resource_vertex.get_edge(self.dag.get_root, edge_type=EdgeType.LINK) # type: DAGEdge + return link_edge is not None + + def get_resource_vertices(self, user_uid: str) -> List[DAGVertex]: + + """ + Get the resource vertices where the user is used for a service or task. + + """ + + user_vertex = self.dag.get_vertex(user_uid) + if user_vertex is None: + return [] + return user_vertex.belongs_to_vertices() + + def get_user_vertices(self, resource_uid: str) -> List[DAGVertex]: + + """ + Get the user vertices that control a service or task on this machine. + + """ + resource_vertex = self.dag.get_vertex(resource_uid) + if resource_vertex is None: + return [] + return resource_vertex.has_vertices() + + @staticmethod + def delete(vertex: DAGVertex): + if vertex is not None: + vertex.delete() + + def save(self): + if self.dag.has_graph: + self.logger.debug("saving the service user.") + self.dag.save(delta_graph=False) + else: + self.logger.debug("the service user graph does not contain any data, was not saved.") + + def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only_active_vertices: bool = True, + show_only_active_edges: bool = True, graph_type: str = "dot"): + + try: + mod = importlib.import_module("graphviz") + except ImportError: + raise Exception("Cannot to_dot(), graphviz module is not installed.") + + dot = getattr(mod, "Digraph")(comment=f"DAG for Services/Tasks", format=graph_format) + + if graph_type == "dot": + dot.attr(rankdir='RL') + elif graph_type == "twopi": + dot.attr(layout="twopi") + dot.attr(ranksep="10") + dot.attr(ratio="auto") + else: + dot.attr(layout=graph_type) + + self.logger.debug(f"have {len(self.dag.all_vertices)} vertices") + for v in self.dag.all_vertices: + if show_only_active_vertices is True and v.active is False: + continue + + tooltip = "" + + for edge in v.edges: + + color = "grey" + style 
    def to_dot(self, graph_format: str = "svg", show_version: bool = True, show_only_active_vertices: bool = True,
               show_only_active_edges: bool = True, graph_type: str = "dot"):
        """Render the service/task graph as a graphviz Digraph.

        ACL edges are colored by their flags (red=service, blue=task,
        green=iis_pool; all three=grey). Requires the optional graphviz
        module; raises Exception when it is not installed.
        """

        try:
            mod = importlib.import_module("graphviz")
        except ImportError:
            raise Exception("Cannot to_dot(), graphviz module is not installed.")

        dot = getattr(mod, "Digraph")(comment=f"DAG for Services/Tasks", format=graph_format)

        # Layout selection: right-to-left ranks for "dot", radial for "twopi",
        # otherwise pass the requested layout engine straight through.
        if graph_type == "dot":
            dot.attr(rankdir='RL')
        elif graph_type == "twopi":
            dot.attr(layout="twopi")
            dot.attr(ranksep="10")
            dot.attr(ratio="auto")
        else:
            dot.attr(layout=graph_type)

        self.logger.debug(f"have {len(self.dag.all_vertices)} vertices")
        for v in self.dag.all_vertices:
            if show_only_active_vertices is True and v.active is False:
                continue

            # Accumulates ACL details across this vertex's edges; attached to
            # the vertex node below.
            tooltip = ""

            for edge in v.edges:

                color = "grey"
                style = "solid"

                # To reduce the number of edges, only show the active edges
                if edge.active:
                    color = "black"
                    style = "bold"
                elif show_only_active_edges:
                    continue

                # If the vertex is not active, gray out the DATA edge
                if edge.edge_type == EdgeType.DATA and v.active is False:
                    color = "grey"

                if edge.edge_type == EdgeType.DELETION:
                    style = "dotted"

                # NOTE(review): edge_tip is never populated, so every edge
                # tooltip is empty — the ACL details below go into the vertex
                # tooltip instead. Confirm whether that is intentional.
                edge_tip = ""
                if edge.edge_type == EdgeType.ACL and v.active is True:
                    content = edge.content_as_dict
                    # Color-code the ACL flags as RGB channels.
                    red = "00"
                    green = "00"
                    blue = "00"
                    if content.get("is_service"):
                        red = "FF"
                    if content.get("is_task"):
                        blue = "FF"
                    if content.get("is_iis_pool"):
                        green = "FF"
                    if red == "FF" and blue == "FF" and green == "FF":
                        color = "#808080"
                    else:
                        color = f"#{red}{green}{blue}"
                    style = "bold"

                    tooltip += f"TO {edge.head_uid}\\n"
                    for k, val in content.items():
                        tooltip += f" * {k}={val}\\n"
                    tooltip += f"--------------------\\n\\n"

                label = DAG.EDGE_LABEL.get(edge.edge_type)
                if label is None:
                    label = "UNK"
                if edge.path is not None and edge.path != "":
                    label += f"\\npath={edge.path}"
                if show_version:
                    label += f"\\nv={edge.version}"

                # tail, head (arrow side), label, ...
                dot.edge(v.uid, edge.head_uid, label, style=style, fontcolor=color, color=color, tooltip=edge_tip)

            shape = "ellipse"
            fillcolor = "white"
            color = "black"
            if not v.active:
                fillcolor = "grey"

            label = f"uid={v.uid}"
            dot.node(v.uid, label, color=color, fillcolor=fillcolor, style="filled", shape=shape, tooltip=tooltip)

        return dot
    def _get_directory_user_vertices(self, configuration_vertex: DAGVertex, domain_name: str) -> List[DAGVertex]:
        """
        Find the directory in the graph and return a list of user vertices.

        Matching is case-insensitive on the domain name. The configuration's
        own "domains" list is checked first; PAM Directory resources are only
        scanned when the configuration does not match.
        """

        domain_name = domain_name.lower()

        user_vertices: List[DAGVertex] = []

        # Check the configuration; it might provide domains.
        # Need to only include the user vertices.
        # If we find it here, we don't need to check for directories; so return with the list.
        config_content = DiscoveryObject.get_discovery_object(configuration_vertex)
        if config_content.record_type in DOMAIN_USER_CONFIGS:
            config_domains = config_content.item.info.get("domains", [])
            self.logger.debug(f" the provider provides domains: {config_domains}")
            for config_domain in config_domains:
                if config_domain.lower() == domain_name:
                    self.logger.debug(f" matched for {domain_name}")
                    # The configuration's children include users and other
                    # object types; keep only the PAM users.
                    for vertex in configuration_vertex.has_vertices():
                        content = DiscoveryObject.get_discovery_object(vertex)
                        if content.record_type == PAM_USER:
                            user_vertices.append(vertex)
                    self.logger.debug(f" found {len(user_vertices)} users for {domain_name}")
                    return user_vertices

        self.logger.debug(" checking pam directories for users")

        # If the configuration did not have domain users, or there were no users, check the PAM Directories.
        for resource_vertex in configuration_vertex.has_vertices():
            content = DiscoveryObject.get_discovery_object(resource_vertex)
            if content.record_type != PAM_DIRECTORY:
                continue
            if content.name.lower() == domain_name:
                user_vertices = resource_vertex.has_vertices()
                self.logger.debug(f" found {len(user_vertices)} users for {domain_name}")
                break

        return user_vertices
+ user_vertices = infra_resource_vertex.has_vertices() + self.logger.debug(f" found {len(user_vertices)} local users") + if domain_name is not None: + user_vertices += self._get_directory_user_vertices( + configuration_vertex=infra_resource_vertex.belongs_to_vertices()[0], + domain_name=domain_name + ) + + self.logger.debug(f" found {len(user_vertices)} total users") + + return user_vertices + + def _connect_service_users(self, + infra_resource_content: DiscoveryObject, + infra_resource_vertex: DAGVertex, + services: List[FactsNameUser]): + + self.logger.debug(f"processing services for {infra_resource_content.description} ({infra_resource_vertex.uid})") + + # We don't care about the name of the service, we just need a list users. + lookup = {} + for service in services: + lookup[service.user.lower()] = True + + infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, + infra_resource_vertex=infra_resource_vertex) + + for infra_user_vertex in infra_user_vertices: + infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) + if infra_user_content.record_uid is None: + continue + if user_in_lookup( + lookup=lookup, + user=infra_user_content.item.user, + name=infra_user_content.name, + source=infra_user_content.item.source): + self.logger.debug(f" * found user for service: {infra_user_content.item.user}") + acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) + if acl is None: + acl = ServiceAcl() + acl.is_service = True + self.belongs_to( + resource_uid=infra_resource_content.record_uid, + resource_name=infra_resource_content.uid, + user_uid=infra_user_content.record_uid, + user_name=infra_user_content.uid, + acl=acl) + + def _connect_task_users(self, + infra_resource_content: DiscoveryObject, + infra_resource_vertex: DAGVertex, + tasks: List[FactsNameUser]): + + self.logger.debug(f"processing tasks for {infra_resource_content.description} ({infra_resource_vertex.uid})") + + # We 
don't care about the name of the tasks, we just need a list users. + lookup = {} + for task in tasks: + lookup[task.user.lower()] = True + + infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, + infra_resource_vertex=infra_resource_vertex) + + for infra_user_vertex in infra_user_vertices: + infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) + if infra_user_content.record_uid is None: + continue + if user_in_lookup( + lookup=lookup, + user=infra_user_content.item.user, + name=infra_user_content.name, + source=infra_user_content.item.source): + self.logger.debug(f" * found user for task: {infra_user_content.item.user}") + acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) + if acl is None: + acl = ServiceAcl() + acl.is_task = True + self.belongs_to( + resource_uid=infra_resource_content.record_uid, + resource_name=infra_resource_content.uid, + user_uid=infra_user_content.record_uid, + user_name=infra_user_content.uid, + acl=acl) + + def _connect_iis_pool_users(self, + infra_resource_content: DiscoveryObject, + infra_resource_vertex: DAGVertex, + iis_pools: List[FactsNameUser]): + + self.logger.debug(f"processing iis pools for " + f"{infra_resource_content.description} ({infra_resource_vertex.uid})") + + # We don't care about the name of the tasks, we just need a list users. 
+ lookup = {} + for iis_pool in iis_pools: + lookup[iis_pool.user.lower()] = True + + infra_user_vertices = self._get_user_vertices(infra_resource_content=infra_resource_content, + infra_resource_vertex=infra_resource_vertex) + + for infra_user_vertex in infra_user_vertices: + infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex) + if infra_user_content.record_uid is None: + continue + if user_in_lookup( + lookup=lookup, + user=infra_user_content.item.user, + name=infra_user_content.name, + source=infra_user_content.item.source): + self.logger.debug(f" * found user for iis pool: {infra_user_content.item.user}") + acl = self.get_acl(infra_resource_content.record_uid, infra_user_content.record_uid) + if acl is None: + acl = ServiceAcl() + acl.is_iis_pool = True + self.belongs_to( + resource_uid=infra_resource_content.record_uid, + resource_name=infra_resource_content.uid, + user_uid=infra_user_content.record_uid, + user_name=infra_user_content.uid, + acl=acl) + + def _validate_users(self, + infra_resource_content: DiscoveryObject, + infra_resource_vertex: DAGVertex): + + """ + This method will check to see if a resource's users' ACL edges are still valid. + + This check will check both local and directory users. + """ + + self.logger.debug(f"validate existing user service edges to see if still valid to " + f"{infra_resource_content.name}") + + service_lookup = {} + for service in infra_resource_content.item.facts.services: + service_lookup[service.user.lower()] = True + + task_lookup = {} + for task in infra_resource_content.item.facts.tasks: + task_lookup[task.user.lower()] = True + + iis_pool_lookup = {} + for iss_pool in infra_resource_content.item.facts.iis_pools: + iis_pool_lookup[iss_pool.user.lower()] = True + + # Get the user service resource vertex. + # If it does not exist, then we cannot validate users. 
    def _validate_users(self,
                        infra_resource_content: DiscoveryObject,
                        infra_resource_vertex: DAGVertex):

        """
        This method will check to see if a resource's users' ACL edges are still valid.

        This check will check both local and directory users. ACL flags whose
        backing service/task/IIS-pool entry has disappeared are cleared; users
        no longer present in the infrastructure are disconnected.
        """

        self.logger.debug(f"validate existing user service edges to see if still valid to "
                          f"{infra_resource_content.name}")

        # Case-insensitive lookup sets of the account names currently backing
        # services, tasks, and IIS pools on this machine.
        service_lookup = {}
        for service in infra_resource_content.item.facts.services:
            service_lookup[service.user.lower()] = True

        task_lookup = {}
        for task in infra_resource_content.item.facts.tasks:
            task_lookup[task.user.lower()] = True

        iis_pool_lookup = {}
        for iss_pool in infra_resource_content.item.facts.iis_pools:
            iis_pool_lookup[iss_pool.user.lower()] = True

        # Get the user service resource vertex.
        # If it does not exist, then we cannot validate users.
        user_service_resource_vertex = self.dag.get_vertex(infra_resource_content.record_uid)
        if user_service_resource_vertex is None:
            return

        infra_dag = infra_resource_vertex.dag

        # The users from the service graph will contain local and directory users.
        for user_service_user_vertex in user_service_resource_vertex.has_vertices():
            acl_edge = user_service_user_vertex.get_edge(
                user_service_resource_vertex, edge_type=EdgeType.ACL)  # type: DAGEdge
            if acl_edge is None:
                self.logger.info(f"User record {user_service_user_vertex.uid} does not have an ACL edge to "
                                 f"{user_service_resource_vertex.uid} for user services.")
                continue

            found_service_acl = False
            found_task_acl = False
            found_iis_pool_acl = False
            changed = False

            acl = acl_edge.content_as_object(ServiceAcl)

            # This will check the entire infrastructure graph for the user with the record UID.
            # This could be a local or directory users.
            user = infra_dag.search_content({"record_type": PAM_USER, "record_uid": user_service_user_vertex.uid})
            infra_user_content = None
            found_user = len(user) > 0
            if found_user:
                infra_user_vertex = user[0]
                if infra_user_vertex.active is False:
                    found_user = False
                else:
                    infra_user_content = DiscoveryObject.get_discovery_object(infra_user_vertex)

            # A user that vanished from the infrastructure loses its link.
            if not found_user:
                self.disconnect_from(user_service_resource_vertex.uid, user_service_user_vertex.uid)
                continue

            # All the aliases this user may appear under in the fact tables.
            check_list = user_check_list(
                user=infra_user_content.item.user,
                name=infra_user_content.name,
                source=infra_user_content.item.source
            )

            # For each set flag, verify a matching fact entry still exists;
            # clear the flag otherwise.
            if acl.is_service:
                for check_user in check_list:
                    if check_user in service_lookup:
                        found_service_acl = True
                        break
                if not found_service_acl:
                    acl.is_service = False
                    changed = True

            if acl.is_task:
                for check_user in check_list:
                    if check_user in task_lookup:
                        found_task_acl = True
                        break
                if not found_task_acl:
                    acl.is_task = False
                    changed = True

            if acl.is_iis_pool:
                for check_user in check_list:
                    if check_user in iis_pool_lookup:
                        found_iis_pool_acl = True
                        break
                if not found_iis_pool_acl:
                    acl.is_iis_pool = False
                    changed = True

            # NOTE(review): when changed is True and no flags remain set, the
            # first branch re-writes an all-False ACL instead of disconnecting
            # — confirm that this is the intended behavior.
            if (found_service_acl is True or found_task_acl is True or found_iis_pool_acl is True) or changed is True:
                self.logger.debug(f"user {user_service_user_vertex.uid}(US) to "
                                  f"{user_service_resource_vertex.uid} updated")
                self.belongs_to(user_service_resource_vertex.uid, user_service_user_vertex.uid, acl)
            elif found_service_acl is False and found_task_acl is False and found_iis_pool_acl is False:
                self.logger.debug(f"user {user_service_user_vertex.uid}(US) to "
                                  f"{user_service_resource_vertex.uid} disconnected")
                self.disconnect_from(user_service_resource_vertex.uid, user_service_user_vertex.uid)

        self.logger.debug(f"DONE validate existing user")
    def run(self, infra: Optional[Infrastructure] = None, **kwargs):
        """
        Map users to services/tasks on machines.

        Walks the infrastructure graph's machines, validates existing ACL
        edges, links users that run services/tasks/IIS pools, then saves.

        IMPORTANT: To avoid memory leaks, pass an existing Infrastructure instance
        instead of letting this method create a new one. Example:
            user_service.run(infra=process.infra)
        """

        self.logger.debug("")
        self.logger.debug("##########################################################################################")
        self.logger.debug("# MAP USER TO MACHINE FOR SERVICE/TASKS")
        self.logger.debug("")

        # If an instance of Infrastructure is not passed in.
        # NOTE: Creating a new Infrastructure instance here can cause memory leaks.
        # Prefer passing an existing instance via the infra parameter.
        _cleanup_infra_on_exit = False
        if infra is None:
            self.logger.warning("Creating new Infrastructure instance - consider passing existing instance to avoid memory leaks")

            # Get ksm from the connection.
            # However, this might be a local connection, so check first.
            # Local connections don't need ksm.
            if hasattr(self.conn, "ksm"):
                kwargs["ksm"] = getattr(self.conn, "ksm")

            # Get the entire infrastructure graph; sync point = 0
            infra = Infrastructure(record=self.record, **kwargs)
            infra.load()
            _cleanup_infra_on_exit = True

        # Work ourselves to the configuration vertex.
        infra_root_vertex = infra.get_root
        infra_config_vertex = infra_root_vertex.has_vertices()[0]

        # For the user service, the root vertex is the equivalent to the infrastructure configuration vertex.
        user_service_config_vertex = self.dag.get_root

        # Find all the resources that are machines.
        for infra_resource_vertex in infra_config_vertex.has_vertices():
            if infra_resource_vertex.active is False or infra_resource_vertex.has_data is False:
                continue
            infra_resource_content = DiscoveryObject.get_discovery_object(infra_resource_vertex)
            if infra_resource_content.record_type == PAM_MACHINE:

                self.logger.debug(f"checking {infra_resource_content.name}")

                # Check the user on the resource if they still are part of a service or task.
                self._validate_users(infra_resource_content, infra_resource_vertex)

                # Do we have services, tasks, iis_pools that are run as a user with a password?
                if infra_resource_content.item.facts.has_service_items is True:

                    # If the resource does not exist in the user service graph, add a vertex and link it to the
                    # user service root/configuration vertex.
                    user_service_resource_vertex = self.dag.get_vertex(infra_resource_content.record_uid)
                    if user_service_resource_vertex is None:
                        user_service_resource_vertex = self.dag.add_vertex(uid=infra_resource_content.record_uid,
                                                                           name=infra_resource_content.description)
                    if not user_service_config_vertex.has(user_service_resource_vertex):
                        user_service_resource_vertex.belongs_to_root(EdgeType.LINK)

                    # Do we have services that are run as a user with a password?
                    if infra_resource_content.item.facts.has_services is True:
                        self._connect_service_users(
                            infra_resource_content,
                            infra_resource_vertex,
                            infra_resource_content.item.facts.services)

                    # Do we have tasks that are run as a user with a password?
                    if infra_resource_content.item.facts.has_tasks is True:
                        self._connect_task_users(
                            infra_resource_content,
                            infra_resource_vertex,
                            infra_resource_content.item.facts.tasks)

                    # Do we have IIS pools that are run as a user with a password?
                    if infra_resource_content.item.facts.has_iis_pools is True:
                        self._connect_iis_pool_users(
                            infra_resource_content,
                            infra_resource_vertex,
                            infra_resource_content.item.facts.iis_pools)

        self.save()

        # Clean up the Infrastructure instance if we created it
        if _cleanup_infra_on_exit and infra is not None:
            self.logger.debug("cleaning up Infrastructure instance created in run()")
            infra.close()
x.type == 'pamResources'), None) + if not self._pam_resources: + self._pam_resources = vault_record.TypedField.new_field('pamResources', []) + self.record.fields.append(self._pam_resources) + + if len(self._pam_resources.value) > 0: + if not isinstance(self._pam_resources.value[0], dict): + self._pam_resources.value.clear() + + if len(self._pam_resources.value) == 0: + if 'pamResources' in record_types.FieldTypes and isinstance(record_types.FieldTypes['pamResources'].value, dict): + value = record_types.FieldTypes['pamResources'].value.copy() + else: + value = {} + self._pam_resources.value.append(value) + + self._port_mapping = next((x for x in self.record.fields + if x.type == 'multiline' and x.label == 'portMapping'), None) + if self._port_mapping is None: + self._port_mapping = vault_record.TypedField.new_field('multiline', [], field_label='portMapping') + self.record.fields.append(self._port_mapping) + + self._file_ref = next((x for x in self.record.fields if x.type == 'fileRef' and x.label == 'rotationScripts'), None) + if self._file_ref is None: + self._file_ref = vault_record.TypedField.new_field('fileRef', [], field_label='rotationScripts') + self.record.fields.append(self._file_ref) + else: + self._pam_resources = None + self._port_mapping = None + self._file_ref = None + + super(PamConfigurationRecordFacade, self).load_typed_fields() + + @property + def controller_uid(self): + return PamConfigurationRecordFacade._controller_uid_getter(self) + + @controller_uid.setter + def controller_uid(self, value): + PamConfigurationRecordFacade._controller_uid_setter(self, value) + + @property + def folder_uid(self): + return PamConfigurationRecordFacade._folder_uid_getter(self) + + @folder_uid.setter + def folder_uid(self, value): + PamConfigurationRecordFacade._folder_uid_setter(self, value) + + @property + def resource_ref(self): + return PamConfigurationRecordFacade._resource_ref_getter(self) + + @property + def rotation_scripts(self): + return 
def boolean_getter(name):
    """Build an accessor that reads attribute *name* (a TypedField) as a tri-state bool.

    The returned callable yields True/False for recognized boolean-ish values
    ('true'/'yes'/'1'/'on' and 'false'/'no'/'0'/'off', case-insensitive), and
    None when the field is missing, empty, or holds an unrecognized value.
    """
    def getter(obj):
        field = getattr(obj, name)
        if isinstance(field, TypedField):
            raw = field.value[0] if len(field.value) > 0 else None
            if raw is None:
                return None
            if isinstance(raw, bool):
                return raw
            text = str(raw).lower()
            if text in ('true', 'yes', '1', 'on'):
                return True
            if text in ('false', 'no', '0', 'off'):
                return False
        return None
    return getter

def boolean_setter(name):
    """Build a mutator that writes a tri-state bool into attribute *name* (a TypedField).

    Recognized string values are normalized to True/False before storing;
    unrecognized values are stored verbatim.  A value of None clears the field.
    """
    def setter(obj, value):
        field = getattr(obj, name)
        if not isinstance(field, TypedField):
            return
        if value is None:
            field.value.clear()
            return
        if not isinstance(value, bool):
            text = str(value).lower()
            if text in ('true', 'yes', '1', 'on'):
                value = True
            elif text in ('false', 'no', '0', 'off'):
                value = False
        if len(field.value) > 0:
            field.value[0] = value
        else:
            field.value.append(value)
    return setter
= string_setter('_oneTimeCode') + + def __init__(self): + super(PamUserRecordFacade, self).__init__() + self._login = None + self._password = None + self._distinguishedName = None + self._connectDatabase = None + self._managed = None + self._oneTimeCode = None + + @property + def login(self): + return PamUserRecordFacade._login_getter(self) + + @login.setter + def login(self, value): + PamUserRecordFacade._login_setter(self, value) + + @property + def password(self): + return PamUserRecordFacade._password_getter(self) + + @password.setter + def password(self, value): + PamUserRecordFacade._password_setter(self, value) + + @property + def distinguishedName(self): + return PamUserRecordFacade._distinguishedName_getter(self) + + @distinguishedName.setter + def distinguishedName(self, value): + PamUserRecordFacade._distinguishedName_setter(self, value) + + @property + def connectDatabase(self): + return PamUserRecordFacade._connectDatabase_getter(self) + + @connectDatabase.setter + def connectDatabase(self, value): + PamUserRecordFacade._connectDatabase_setter(self, value) + + @property + def managed(self): + return PamUserRecordFacade._connectDatabase_getter(self) + + @managed.setter + def managed(self, value): + PamUserRecordFacade._managed_setter(self, value) + + @property + def oneTimeCode(self): + return PamUserRecordFacade._oneTimeCode_getter(self) + + @oneTimeCode.setter + def oneTimeCode(self, value): + PamUserRecordFacade._oneTimeCode_setter(self, value) + + def load_typed_fields(self): + if self.record: + self.record.type_name = 'pamUser' + for attr in ["login", "password", "distinguishedName", "connectDatabase", "managed", "oneTimeCode"]: + attr_prv = f"_{attr}" + value = next((x for x in self.record.fields if x.type == attr), None) + setattr(self, attr_prv, value) + if value is None: + value = TypedField.new_field(attr, '') + setattr(self, attr_prv, value) + self.record.fields.append(value) + else: + for attr in ["_login", "_password", "_distinguishedName", 
"_connectDatabase", "_managed", "_oneTimeCode"]: + setattr(self, attr, None) + super(PamUserRecordFacade, self).load_typed_fields() diff --git a/keepersdk-package/src/keepersdk/helpers/tunnel/tunnel_graph.py b/keepersdk-package/src/keepersdk/helpers/tunnel/tunnel_graph.py new file mode 100644 index 00000000..683eeccd --- /dev/null +++ b/keepersdk-package/src/keepersdk/helpers/tunnel/tunnel_graph.py @@ -0,0 +1,540 @@ +import logging + +from . import tunnel_utils + +from ..keeper_dag.connection.commander import Connection +from ..keeper_dag.dag_types import EdgeType, RefType, PamEndpoints +from ..keeper_dag.dag import DAG +from ..keeper_dag.dag_vertex import DAGVertex +from ..keeper_dag.dag_crypto import generate_random_bytes + +from ...vault import vault_online, vault_record + +logger = logging.getLogger(__name__) + +def get_vertex_content(vertex): + return_content = None + if vertex is None: + return return_content + try: + return_content = vertex.content_as_dict + except Exception as e: + logger.debug(f"Error getting vertex content: {e}") + return_content = None + return return_content + + +class TunnelDAG: + def __init__(self, vault: vault_online.VaultOnline, encrypted_session_token, encrypted_transmission_key, record_uid: str, is_config=False): + config_uid = None + if not is_config: + config_uid = tunnel_utils.get_config_uid(vault, encrypted_session_token, encrypted_transmission_key, record_uid) + if not config_uid: + config_uid = record_uid + self.record = vault_record.PasswordRecord() + self.record.record_uid = config_uid + self.record.record_key = generate_random_bytes(32) + self.encrypted_session_token = encrypted_session_token + self.encrypted_transmission_key = encrypted_transmission_key + self.conn = Connection(vault=vault, encrypted_transmission_key=self.encrypted_transmission_key, + encrypted_session_token=self.encrypted_session_token, + use_write_protobuf=True + ) + self.linking_dag = DAG(conn=self.conn, record=self.record, graph_id=0, 
write_endpoint=PamEndpoints.PAM) + try: + self.linking_dag.load() + except Exception as e: + import logging + logging.debug(f"Error loading config: {e}") + + def resource_belongs_to_config(self, resource_uid): + if not self.linking_dag.has_graph: + return False + resource_vertex = self.linking_dag.get_vertex(resource_uid) + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + return resource_vertex and config_vertex.has(resource_vertex, EdgeType.LINK) + + def user_belongs_to_config(self, user_uid): + if not self.linking_dag.has_graph: + return False + user_vertex = self.linking_dag.get_vertex(user_uid) + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + res_content = False + if user_vertex and config_vertex and config_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(config_vertex, EdgeType.ACL) + _content = acl_edge.content_as_dict + res_content = _content.get('belongs_to', False) if _content else False + return res_content + + def check_tunneling_enabled_config(self, enable_connections=None, enable_tunneling=None, + enable_rotation=None, enable_session_recording=None, + enable_typescript_recording=None, remote_browser_isolation=None): + if not self.linking_dag.has_graph: + return False + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + content = get_vertex_content(config_vertex) + if content is None or not content.get('allowedSettings'): + return False + + allowed_settings = content['allowedSettings'] + if enable_connections and not allowed_settings.get("connections"): + return False + if enable_tunneling and not allowed_settings.get("portForwards"): + return False + if enable_rotation and not allowed_settings.get("rotation"): + return False + if allowed_settings.get("connections") and allowed_settings["connections"]: + if enable_session_recording and not allowed_settings.get("sessionRecording"): + return False + if enable_typescript_recording and not 
allowed_settings.get("typescriptRecording"): + return False + if remote_browser_isolation and not allowed_settings.get("remoteBrowserIsolation"): + return False + return True + + @staticmethod + def _convert_allowed_setting(value): + """Converts on/off/default|any to True/False/None""" + if value is None or isinstance(value, bool): + return value + return {"on": True, "off": False}.get(str(value).lower(), None) + + def edit_tunneling_config(self, connections=None, tunneling=None, + rotation=None, session_recording=None, + typescript_recording=None, + remote_browser_isolation=None): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid, vertex_type=RefType.PAM_NETWORK) + + if config_vertex.vertex_type != RefType.PAM_NETWORK: + config_vertex.vertex_type = RefType.PAM_NETWORK + content = get_vertex_content(config_vertex) + if content and content.get('allowedSettings'): + allowed_settings = dict(content['allowedSettings']) + del content['allowedSettings'] + content = {'allowedSettings': allowed_settings} + + if content is None: + content = {'allowedSettings': {}} + if 'allowedSettings' not in content: + content['allowedSettings'] = {} + + allowed_settings = content['allowedSettings'] + dirty = False + + if connections is not None: + connections = self._convert_allowed_setting(connections) + if connections != allowed_settings.get("connections", None): + dirty = True + if connections is None: + allowed_settings.pop("connections", None) + else: + allowed_settings["connections"] = connections + + if tunneling is not None: + tunneling = self._convert_allowed_setting(tunneling) + if tunneling != allowed_settings.get("portForwards", None): + dirty = True + if tunneling is None: + allowed_settings.pop("portForwards", None) + else: + allowed_settings["portForwards"] = tunneling + + if rotation is not None: + rotation = self._convert_allowed_setting(rotation) + if 
rotation != allowed_settings.get("rotation", None): + dirty = True + if rotation is None: + allowed_settings.pop("rotation", None) + else: + allowed_settings["rotation"] = rotation + + if session_recording is not None: + session_recording = self._convert_allowed_setting(session_recording) + if session_recording != allowed_settings.get("sessionRecording", None): + dirty = True + if session_recording is None: + allowed_settings.pop("sessionRecording", None) + else: + allowed_settings["sessionRecording"] = session_recording + + if typescript_recording is not None: + typescript_recording = self._convert_allowed_setting(typescript_recording) + if typescript_recording != allowed_settings.get("typescriptRecording", None): + dirty = True + if typescript_recording is None: + allowed_settings.pop("typescriptRecording", None) + else: + allowed_settings["typescriptRecording"] = typescript_recording + + if remote_browser_isolation is not None: + remote_browser_isolation = self._convert_allowed_setting(remote_browser_isolation) + if remote_browser_isolation != allowed_settings.get("remoteBrowserIsolation", None): + dirty = True + if remote_browser_isolation is None: + allowed_settings.pop("remoteBrowserIsolation", None) + else: + allowed_settings["remoteBrowserIsolation"] = remote_browser_isolation + + if dirty: + config_vertex.add_data(content=content, path='meta', needs_encryption=False) + self.linking_dag.save() + + def get_all_owners(self, uid): + owners = [] + if self.linking_dag.has_graph: + vertex = self.linking_dag.get_vertex(uid) + if vertex: + owners = [owner.uid for owner in vertex.belongs_to_vertices()] + return owners + + def user_belongs_to_resource(self, user_uid, resource_uid): + user_vertex = self.linking_dag.get_vertex(user_uid) + resource_vertex = self.linking_dag.get_vertex(resource_uid) + res_content = False + if user_vertex and resource_vertex and resource_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(resource_vertex, 
EdgeType.ACL) + _content = acl_edge.content_as_dict + res_content = _content.get('belongs_to', False) if _content else False + return res_content + + def get_resource_uid(self, user_uid): + if not self.linking_dag.has_graph: + return None + resources = self.get_all_owners(user_uid) + if len(resources) > 0: + for resource in resources: + if self.user_belongs_to_resource(user_uid, resource): + return resource + return None + + def link_resource_to_config(self, resource_uid): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid) + + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + resource_vertex = self.linking_dag.add_vertex(uid=resource_uid) + + if not config_vertex.has(resource_vertex, EdgeType.LINK): + resource_vertex.belongs_to(config_vertex, EdgeType.LINK) + self.linking_dag.save() + + def link_user_to_config(self, user_uid): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid) + self.link_user(user_uid, config_vertex, belongs_to=True, is_iam_user=True) + + def link_user_to_config_with_options(self, user_uid, is_admin=None, belongs_to=None, is_iam_user=None): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid) + + # self.link_user(user_uid, config_vertex, is_admin, belongs_to, is_iam_user) + source_vertex = config_vertex + user_vertex = self.linking_dag.get_vertex(user_uid) + if user_vertex is None: + user_vertex = self.linking_dag.add_vertex(uid=user_uid, vertex_type=RefType.PAM_USER) + + # switching to 3-state on/off/default: on/true, off/false, + # None = Keep existing, 'default' = Reset to default (remove from dict) + states = {'on': True, 'off': False, 'default': 
'', 'none': None} + + content = { + "belongs_to": states.get(str(belongs_to).lower()), + "is_admin": states.get(str(is_admin).lower()), + "is_iam_user": states.get(str(is_iam_user).lower()) + } + if user_vertex.vertex_type != RefType.PAM_USER: + user_vertex.vertex_type = RefType.PAM_USER + + dirty = False + if source_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(source_vertex, EdgeType.ACL) + existing_content = acl_edge.content_as_dict or {} + old_content = existing_content.copy() + for key in list(existing_content.keys()): + if content.get(key) is not None: + if content[key] == '': + existing_content.pop(key) + elif content[key] in (True, False): + existing_content[key] = content[key] + content = {k: v for k, v in content.items() if v not in (None, '')} + for k, v in content.items(): + existing_content.setdefault(k, v) + if existing_content != old_content: + dirty = True + + if dirty: + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=existing_content) + # user_vertex.add_data(content=existing_content, needs_encryption=False) + self.linking_dag.save() + else: + content = {k: v for k, v in content.items() if v not in (None, '')} + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=content) + self.linking_dag.save() + + def unlink_user_from_resource(self, user_uid, resource_uid) -> bool: + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None or not self.resource_belongs_to_config(resource_uid): + logger.error(f"Resource {resource_uid} does not belong to the configuration") + return False + + user_vertex = self.linking_dag.get_vertex(user_uid) + if user_vertex is None or user_vertex.vertex_type != RefType.PAM_USER: + return False + + if resource_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(resource_vertex, EdgeType.ACL) + edge_content = acl_edge.content_as_dict or {} + link_keys = ('belongs_to', 'is_admin') # "is_iam_user" + dirty = any(key in link_keys for 
key in edge_content) + if dirty: + for link_key in link_keys: + edge_content.pop(link_key, None) + user_vertex.belongs_to(resource_vertex, EdgeType.ACL, content=edge_content) + self.linking_dag.save() + return True + + return False + + def link_user_to_resource(self, user_uid, resource_uid, is_admin=None, belongs_to=None): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None or not self.resource_belongs_to_config(resource_uid): + logger.error(f"Resource {resource_uid} does not belong to the configuration") + return False + self.link_user(user_uid, resource_vertex, is_admin, belongs_to) + return None + + def link_user(self, user_uid, source_vertex: DAGVertex, is_admin=None, belongs_to=None, is_iam_user=None): + + user_vertex = self.linking_dag.get_vertex(user_uid) + if user_vertex is None: + user_vertex = self.linking_dag.add_vertex(uid=user_uid, vertex_type=RefType.PAM_USER) + + content = {} + dirty = False + if belongs_to is not None: + content["belongs_to"] = bool(belongs_to) + if is_admin is not None: + content["is_admin"] = bool(is_admin) + if is_iam_user is not None: + content["is_iam_user"] = bool(is_iam_user) + + if user_vertex.vertex_type != RefType.PAM_USER: + user_vertex.vertex_type = RefType.PAM_USER + + if source_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(source_vertex, EdgeType.ACL) + existing_content = acl_edge.content_as_dict or {} + for key in existing_content: + if key not in content: + content[key] = existing_content[key] + if content != existing_content: + dirty = True + + if dirty: + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=content) + self.linking_dag.save() + else: + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=content) + self.linking_dag.save() + + def get_all_admins(self): + if not self.linking_dag.has_graph: + return [] + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + return [] + admins = [] 
+ for user_vertex in config_vertex.has_vertices(EdgeType.ACL): + acl_edge = user_vertex.get_edge(config_vertex, EdgeType.ACL) + if acl_edge: + content = acl_edge.content_as_dict + if content.get('is_admin'): + admins.append(user_vertex.uid) + return admins + + def check_if_resource_has_admin(self, resource_uid): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + return False + for user_vertex in resource_vertex.has_vertices(EdgeType.ACL): + acl_edge = user_vertex.get_edge(resource_vertex, EdgeType.ACL) + if acl_edge: + content = acl_edge.content_as_dict + if content.get('is_admin'): + return user_vertex.uid + return False + + def check_if_resource_allowed(self, resource_uid, setting): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + content = get_vertex_content(resource_vertex) + return content.get('allowedSettings', {}).get(setting, False) if content else False + + def get_resource_setting(self, resource_uid: str, settings_name: str, setting: str) -> str: + # Settings are tri-state (on|off|default) mapped to true|false|missing in JSON + # When set to "default" (missing from JSON) that means look higher up the hierarchy + # ex. rotation: user -> machine -> pam_config -> Gobal Default settings + # Note: Different clients (even different client versions) + # may have different view on these defaults (Commander, Web Vault, etc.) 
+ resource_vertex = self.linking_dag.get_vertex(resource_uid) + content = get_vertex_content(resource_vertex) + res = '' + if content and isinstance(content, dict): + if settings_name in content and isinstance(content[settings_name], dict): + if setting in content[settings_name]: + value = content[settings_name][setting] + if isinstance(value, bool): + res = {True: 'on', False: 'off'}[value] + else: + res = str(value) + else: + res = 'default' + + return res + + def set_resource_allowed(self, resource_uid, tunneling=None, connections=None, rotation=None, + session_recording=None, typescript_recording=None, remote_browser_isolation=None, + allowed_settings_name='allowedSettings', is_config=False, + v_type: RefType=str(RefType.PAM_MACHINE)): + v_type = RefType(v_type) + allowed_ref_types = [RefType.PAM_MACHINE, RefType.PAM_DATABASE, RefType.PAM_DIRECTORY, RefType.PAM_BROWSER] + if v_type not in allowed_ref_types: + # default to machine + v_type = RefType.PAM_MACHINE + + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + resource_vertex = self.linking_dag.add_vertex(uid=resource_uid, vertex_type=v_type) + + if resource_vertex.vertex_type not in allowed_ref_types: + resource_vertex.vertex_type = v_type + if is_config: + resource_vertex.vertex_type = RefType.PAM_NETWORK + dirty = False + content = get_vertex_content(resource_vertex) + if content is None: + content = {allowed_settings_name: {}} + dirty = True + if allowed_settings_name not in content: + content[allowed_settings_name] = {} + dirty = True + + settings = content[allowed_settings_name] + + # When no value in allowedSettings: client will substitute with default + # rotation defaults to True, everything else defaults to False + + # switching to 3-state on/off/default: on/true, off/false, + # None = Keep existing, 'default' = Reset to default (remove from dict) + if connections is not None: + connections = self._convert_allowed_setting(connections) + if connections != 
settings.get("connections", None): + dirty = True + if connections is None: + settings.pop("connections", None) + else: + settings["connections"] = connections + + if tunneling is not None: + tunneling = self._convert_allowed_setting(tunneling) + if tunneling != settings.get("portForwards", None): + dirty = True + if tunneling is None: + settings.pop("portForwards", None) + else: + settings["portForwards"] = tunneling + + if rotation is not None: + rotation = self._convert_allowed_setting(rotation) + if rotation != settings.get("rotation", None): + dirty = True + if rotation is None: + settings.pop("rotation", None) + else: + settings["rotation"] = rotation + + if session_recording is not None: + session_recording = self._convert_allowed_setting(session_recording) + if session_recording != settings.get("sessionRecording", None): + dirty = True + if session_recording is None: + settings.pop("sessionRecording", None) + else: + settings["sessionRecording"] = session_recording + + if typescript_recording is not None: + typescript_recording = self._convert_allowed_setting(typescript_recording) + if typescript_recording != settings.get("typescriptRecording", None): + dirty = True + if typescript_recording is None: + settings.pop("typescriptRecording", None) + else: + settings["typescriptRecording"] = typescript_recording + + if remote_browser_isolation is not None: + remote_browser_isolation = self._convert_allowed_setting(remote_browser_isolation) + if remote_browser_isolation != settings.get("remoteBrowserIsolation", None): + dirty = True + if remote_browser_isolation is None: + settings.pop("remoteBrowserIsolation", None) + else: + settings["remoteBrowserIsolation"] = remote_browser_isolation + + if dirty: + resource_vertex.add_data(content=content, path='meta', needs_encryption=False) + self.linking_dag.save() + + def is_tunneling_config_set_up(self, resource_uid): + if not self.linking_dag.has_graph: + return False + resource_vertex = 
self.linking_dag.get_vertex(resource_uid) + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + return resource_vertex and config_vertex and config_vertex in resource_vertex.belongs_to_vertices() + + def remove_from_dag(self, uid): + if not self.linking_dag.has_graph: + return True + + vertex = self.linking_dag.get_vertex(uid) + if vertex is None: + return True + + vertex.delete() + self.linking_dag.save(confirm=True) + return None + + def print_tunneling_config(self, record_uid, pam_settings=None, config_uid=None): + if not pam_settings and not config_uid: + return + self.linking_dag.load() + vertex = self.linking_dag.get_vertex(record_uid) + content = get_vertex_content(vertex) + config_id = config_uid if config_uid else pam_settings.value[0].get('configUid') if pam_settings else None + if content and content.get('allowedSettings'): + allowed_settings = content['allowedSettings'] + logger.info(f"Settings configured for {record_uid}") + port_forwarding = f"Enabled" if allowed_settings.get('portForwards') else \ + "Disabled" + rotation = "Disabled" if (allowed_settings.get('rotation') and not allowed_settings['rotation']) else "Enabled" + logger.info(f"\tRotation: {rotation}") + logger.info(f"\tTunneling: {port_forwarding}") + + logger.info(f"Configuration: {config_id}") + if config_id is not None: + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + config_content = get_vertex_content(config_vertex) + if config_content and config_content.get('allowedSettings'): + config_allowed_settings = config_content['allowedSettings'] + config_port_forwarding = "Enabled" if ( + config_allowed_settings.get('portForwards')) else \ + "Disabled" + config_rotation = "Disabled" if (config_allowed_settings.get('rotation') and + not config_allowed_settings['rotation']) else \ + "Enabled" + logger.info(f"\tRotation: {config_rotation}") + logger.info(f"\tTunneling: {config_port_forwarding}") diff --git 
def get_dag_leafs(vault: vault_online.VaultOnline, encrypted_session_token: bytes, encrypted_transmission_key: bytes, record_id: str) -> Optional[List[Dict[str, Any]]]:
    """POST a leaf-lookup request to the KRouter and return the decoded JSON list.

    The request body is ``{"vertex": <record_id>, "graphId": 0}`` posted to
    ``/api/user/get_leafs``.  (The earlier docstring said ``/api/dag/get_leafs``
    while the code called ``/api/user/get_leafs``; the code path is kept —
    TODO confirm which endpoint is canonical.)

    Returns:
        The response JSON (a list of leaf dicts) on HTTP 200, otherwise None.

    Raises:
        errors.KeeperApiError: when the KRouter host is unreachable.
    """
    krouter_host = f"https://{vault.keeper_auth.keeper_endpoint.get_router_server()}"
    path = '/api/user/get_leafs'

    payload = {
        'vertex': record_id,
        'graphId': 0
    }

    try:
        rs = requests.request(
            'post',
            krouter_host + path,
            verify=VERIFY_SSL,
            headers={
                'TransmissionKey': utils.base64_url_encode(encrypted_transmission_key),
                'Authorization': f'KeeperUser {utils.base64_url_encode(encrypted_session_token)}'
            },
            data=json.dumps(payload).encode('utf-8')
        )
    except requests.exceptions.ConnectionError as e:
        # Fixes: the handler previously caught the builtin ConnectionError,
        # which requests.exceptions.ConnectionError does not inherit from, so
        # it never fired; and the message used '${e}', emitting a literal '$'.
        # (The redundant `except Exception as ex: raise ex` was removed — it
        # only re-raised unchanged.)
        raise errors.KeeperApiError(-1, f"KRouter is not reachable on '{krouter_host}'. Error: {e}") from e

    if rs.status_code == 200:
        logger.debug("Found right host")
        return rs.json()

    logger.warning("Looks like there is no such controller connected to the router.")
    return None
def ssl_aware_request(method, url, **kwargs):
    """Issue an HTTP request, resolving TLS verification from the system CA store.

    When the caller does not pass `verify`, get_ssl_cert_file() decides:
    False disables verification, a truthy path pins the CA bundle, and None
    leaves the requests default in place.
    """
    import requests

    if 'verify' not in kwargs:
        cert_file = get_ssl_cert_file()
        # Both the explicit False (verification disabled) and a truthy bundle
        # path are forwarded verbatim; None keeps requests' default behavior.
        if cert_file is False or cert_file:
            kwargs['verify'] = cert_file

    return requests.request(method, url, **kwargs)


def ssl_aware_get(url, **kwargs):
    """Convenience wrapper: GET via ssl_aware_request with system CA handling."""
    return ssl_aware_request('GET', url, **kwargs)