diff --git a/sshkey.py b/sshkey.py index a347aed..15721f5 100644 --- a/sshkey.py +++ b/sshkey.py @@ -37,8 +37,35 @@ class SSHKey: def __int__(self) -> int: return self.__idx + + def update_status(self) -> None: + if isinstance(self.__value, tuple): + privkey_present = False + pubkey_present = False + for p in self.__value: + if "-----BEGIN OPENSSH PRIVATE KEY-----" in p.read_text(): + privkey_present = True + else: + pubkey_present = True + + if pubkey_present and privkey_present: + self.category = SSHKeyType.dual.name.lower() + elif pubkey_present or privkey_present: + if pubkey_present: + self.category = SSHKeyType.pubkey.name.lower() + if privkey_present: + self.category = SSHKeyType.privkey.name.lower() + elif isinstance(self.__value, ExecutedPath): + if "-----BEGIN OPENSSH PRIVATE KEY-----" in self.__value.read_text(): + self.category = SSHKeyType.privkey.name.lower() + else: + self.category = SSHKeyType.pubkey.name.lower() + def __str__(self) -> str: - key_basename = Path(str(self.__value)).name + if isinstance(self.__value, tuple): + key_basename = Path(str(self.__value[0])).stem + "." + self.category + else: + key_basename = Path(str(self.__value)).name return "🔑" + key_basename def __repr__(self) -> str: @@ -107,41 +134,75 @@ class SSHKey: return self - def __add__(self, other: Self | ExecutedPath | str): - if isinstance(self.__value, tuple): - raise ValueError - + def __add__(self, other: Self | ExecutedPath | str) -> Self: if isinstance(other, (str, ExecutedPath)): - result = self.update(self.__value, other) + if isinstance(self.__value, tuple): + self.update(*self.__value, other) + else: + self.update(*self.__value, other) else: - if isinstance(other.__SSHKey__value, tuple): - raise ValueError + if isinstance(other._SSHKey__value, tuple): + if isinstance(self.__value, tuple): + self.update(*self.__value, *other._SSHKey__value) + else: + self.update(self.__value, *other._SSHKey__value) + else: + if isinstance(self.__value, tuple): + self.update(*self.__value, other._SSHKey__value) + else: + self.update(self.__value, other._SSHKey__value) - result = self.update(self.__value, other._SSHKey__value) + self.update_status() - return result + return self - def __radd__(self, other: Self | ExecutedPath | str): + def __radd__(self, other: Self | ExecutedPath | str) -> Self: if isinstance(self.__value, tuple): - raise ValueError - - if isinstance(other, (str, ExecutedPath)): - result = self.update(other, self.__value) + if isinstance(other, (ExecutedPath, str)): + other.update(other, *self.__value) + else: + if isinstance(other._SSHKey__value, tuple): + other.update(*other._SSHKey__value, *self.__value) + else: + other.update(other._SSHKey__value, *self.__value) else: - if isinstance(other.__SSHKey__value, tuple): - raise ValueError + if isinstance(other, (ExecutedPath, str)): + other.update(other, self.__value) + else: + if isinstance(other._SSHKey__value, tuple): + other.update(*other._SSHKey__value, self.__value) + else: + other.update(other._SSHKey__value, self.__value) - result = self.update(other._SSHKey__value, self.__value) + other.update_status() - return result + return self # @TODO write following 2 subtraction algorithms using 'set' data type conversion and methods - def __sub__(self, other: Self | ExecutedPath | str): + def __sub__(self, other: Self | ExecutedPath | str) -> Never: raise NotImplementedError - def __rsub__(self, other: Self | ExecutedPath | str): + def __rsub__(self, other: Self | ExecutedPath | str) -> Never: raise NotImplementedError + def __getitem__(self, key: int) -> ExecutedPath | str | Never: + if isinstance(self.__value, tuple): + return self.__value[key] + else: + raise KeyError + + def __setitem__(self, key: int, value: ExecutedPath | str) -> None: + if isinstance(self.__value, tuple): + new_entry = list(self.__value) + + if isinstance(value, str): + value = Path(value) + + new_entry[key] = value + self.__value = tuple(new_entry) + else: + raise KeyError + def replace(self, old: ExecutedPath | str | tuple[ExecutedPath | str] | list[ExecutedPath | str], new: ExecutedPath | str | tuple[ExecutedPath | str] | list[ExecutedPath | str]) -> Self | Never: if isinstance(old, str): old = Path(old) @@ -206,29 +267,6 @@ class SSHKey: return result - def update_status(self) -> None: - if isinstance(self.__value, tuple): - privkey_present = False - pubkey_present = False - for p in self.__value: - if "-----BEGIN OPENSSH PRIVATE KEY-----" in p.read_text(): - privkey_present = True - else: - pubkey_present = True - - if pubkey_present and privkey_present: - self.category = SSHKeyType.dual.name.lower() - elif pubkey_present or privkey_present: - if pubkey_present: - self.category = SSHKeyType.pubkey.name.lower() - if privkey_present: - self.category = SSHKeyType.privkey.name.lower() - elif isinstance(self.__value, ExecutedPath): - if "-----BEGIN OPENSSH PRIVATE KEY-----" in self.__value.read_text(): - self.category = SSHKeyType.privkey.name.lower() - else: - self.category = SSHKeyType.pubkey.name.lower() - @property def status(self) -> str: self.update_status() @@ -265,7 +303,7 @@ def stream_unequal(s1, s2) -> bool | Never: # @TODO create unit tests for below class class SSHKeyCollection(Sequence): - __user_path: ExecutedPath = USER_PATH + __user_path: ExecutedPath = USER_PATH() __ssh_path: ExecutedPath = __user_path / ".ssh" def __init__(self): @@ -502,6 +540,8 @@ class SSHKeyCollection(Sequence): return self.__last def __contains__(self, value: ExecutedPath | str) -> bool: + self.__current = self.__first + if isinstance(value, ExecutedPath): value = str(value) @@ -522,7 +562,8 @@ class SSHKeyCollection(Sequence): self.__current = next(self.__current) if self.__current is not None: return self.__current - raise StopIteration + else: + raise StopIteration def __iter__(self) -> Self | Never: self.__current = self.__first @@ -562,7 +603,7 @@ class SSHKeyCollection(Sequence): self.__current = self.__first - concat = lambda s: str(s)[1:] + ", " + concat = lambda s: str(s) + ", " content = str() count = 0 while self.__current is not None: @@ -573,13 +614,17 @@ class SSHKeyCollection(Sequence): return prefix + content + postfix - def publish(self, category: SSHKeyType = SSHKeyType.pubkey.name.lower(), pref: int | None = None, datatype = dict): + def index(self, item: str | ExecutedPath) -> Never: + raise NotImplementedError + + def publish(self, category: SSHKeyType | str = SSHKeyType.pubkey, pref: int | None = None, datatype = dict): privkey = list() pubkey = list() self.__current = self.__first if datatype == list: while self.__current is not None: + # print(self.__current) if self.__current.category == SSHKeyType.privkey.name.lower(): privkey.append(self.__current._SSHKey__value) elif self.__current.category == SSHKeyType.pubkey.name.lower(): @@ -587,16 +632,18 @@ class SSHKeyCollection(Sequence): elif self.__current.category == SSHKeyType.dual.name.lower(): privkey.append(self.__current._SSHKey__value[0]) pubkey.append(self.__current._SSHKey__value[1]) - self.__current = next(self.__first) + self.__current = next(self.__current) + # print("publish running...") if pref is None: preference = gamble(range(len(privkey))) else: preference = pref - if category == SSHKeyType.pubkey.name.lower(): + # print(category) + if category.name.lower() == SSHKeyType.pubkey.name.lower(): return pubkey - elif category == SSHKeyType.privkey.name.lower(): + elif category.name.lower() == SSHKeyType.privkey.name.lower(): return (privkey, preference) else: return (privkey, pubkey, preference) @@ -604,6 +651,7 @@ class SSHKeyCollection(Sequence): result = dict() while self.__current is not None: + # print(self.__current) if self.__current.category == SSHKeyType.privkey.name.lower(): privkey.append(str(self.__current._SSHKey__value)) elif self.__current.category == SSHKeyType.pubkey.name.lower(): @@ -611,16 +659,17 @@ class SSHKeyCollection(Sequence): elif self.__current.category == SSHKeyType.dual.name.lower(): privkey.append(str(self.__current._SSHKey__value[0])) pubkey.append(self.__current._SSHKey__value[1].read_text()) - self.__current = next(self.__first) + self.__current = next(self.__current) + # print("publish running...") - if category == SSHKeyType.pubkey.name.lower(): + if category.name.lower() == SSHKeyType.pubkey.name.lower(): result["ssh_authorized_keys"]: list[str] = pubkey - if category == SSHKeyType.privkey.name.lower(): + if category.name.lower() == SSHKeyType.privkey.name.lower(): result["ssh_private_key_paths"]: list[str] = privkey result["ssh_private_key_path_pref"]: int = pref if pref is not None else gamble(range(len(privkey))) - if category == SSHKeyType.dual.name.lower(): + if category.name.lower() == SSHKeyType.dual.name.lower(): result["ssh_authorized_keys"]: list[str] = pubkey result["ssh_private_key_paths"]: list[str] = privkey result["ssh_private_key_path_pref"]: int = pref if pref is not None else gamble(range(len(privkey)))