feature: added magic method for containment test of 'SSHKeyCollection' and magic method for setting the status of 'SSHKey'

This commit is contained in:
2026-01-20 12:14:53 -05:00
parent 67e645a3bd
commit 28bd322c87

View File

@@ -1,4 +1,5 @@
from re import Pattern as RegEx from re import Pattern as RegEx
from re import fullmatch as Match
from pathlib import Path, PurePath from pathlib import Path, PurePath
from custtypes import ExecutedPath, IdlePath, VirtualPrivateServers, AnsibleScopes from custtypes import ExecutedPath, IdlePath, VirtualPrivateServers, AnsibleScopes
from enum import Enum from enum import Enum
@@ -214,11 +215,34 @@ class SSHKey:
return result return result
def update_status(self) -> None:
if isinstance(self.__value, tuple):
privkey_present = False
pubkey_present = False
for p in self.__value:
if "-----BEGIN OPENSSH PRIVATE KEY-----" in p.read_text():
privkey_present = True
else:
pubkey_present = True
if pubkey_present and privkey_present:
self.category = SSHKeyType.dual.name
elif pubkey_present or privkey_present:
if pubkey_present:
self.category = SSHKeyType.pubkey.name
if privkey_present:
self.category = SSHKeyType.privkey.name
elif isinstance(self.__value, ExecutedPath):
if "-----BEGIN OPENSSH PRIVATE KEY-----" in self.__value.read_text():
self.category = SSHKeyType.privkey.name
else:
self.category = SSHKeyType.pubkey.name
@property @property
def status(self) -> str | Never: def status(self) -> str:
# @TODO analyze 'read' return value of this class's instance to check whether it is a private key, self.update_status()
# unless '__value' of this class's instance is a tuple, in which case it is a dual key
raise NotImplementedError return self.category
def reverse(self) -> Never: def reverse(self) -> Never:
if isinstance(self.__value, tuple): if isinstance(self.__value, tuple):
@@ -378,6 +402,7 @@ class SSHKeyCollection(Sequence):
# print("branch1") # print("branch1")
ssh_key._SSHKey__idx = 0 ssh_key._SSHKey__idx = 0
# print(ssh_key._SSHKey__idx) # print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__indices = range(ssh_key._SSHKey__idx + 1) self.__indices = range(ssh_key._SSHKey__idx + 1)
self.__first = ssh_key self.__first = ssh_key
@@ -388,6 +413,7 @@ class SSHKeyCollection(Sequence):
# print("branch2.1") # print("branch2.1")
ssh_key._SSHKey__idx = self.__last._SSHKey__idx + 1 ssh_key._SSHKey__idx = self.__last._SSHKey__idx + 1
# print(ssh_key._SSHKey__idx) # print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__last._SSHKey__next = ssh_key self.__last._SSHKey__next = ssh_key
self.__last._SSHKey__next._SSHKey__prev = self.__last self.__last._SSHKey__next._SSHKey__prev = self.__last
self.__last = next(self.__last) self.__last = next(self.__last)
@@ -395,6 +421,7 @@ class SSHKeyCollection(Sequence):
# print("branch2.2") # print("branch2.2")
ssh_key._SSHKey__idx = self.__first._SSHKey__idx + 1 ssh_key._SSHKey__idx = self.__first._SSHKey__idx + 1
# print(ssh_key._SSHKey__idx) # print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__first._SSHKey__next = ssh_key self.__first._SSHKey__next = ssh_key
self.__first._SSHKey__next._SSHKey__prev = self.__first self.__first._SSHKey__next._SSHKey__prev = self.__first
self.__last = self.__first._SSHKey__next self.__last = self.__first._SSHKey__next
@@ -483,10 +510,21 @@ class SSHKeyCollection(Sequence):
return self.__last return self.__last
def __contains__(self) -> bool | Never: def __contains__(self, value: ExecutedPath | str) -> bool:
raise NotImplementedError if isinstance(value, ExecutedPath):
value = str(value)
def __missing__(self) -> Never: is_contained = False
while self.__current is not None:
if str(self.__current._SSHKey__value) == value:
is_contained = True
break
self.__current = next(self.__current)
return is_contained
def __missing__(self, value: ExecutedPath | str) -> Never:
raise NotImplementedError raise NotImplementedError
def __next__(self): def __next__(self):
@@ -504,12 +542,13 @@ class SSHKeyCollection(Sequence):
raise NotImplementedError raise NotImplementedError
# @TODO make sure to implement below method # @TODO make sure to implement below method
def pull(self, query: RegEx | str = "*") -> Never: def pull(self, query: RegEx | str = "*") -> None:
if isinstance(query, RegEx): if isinstance(query, RegEx):
keypaths = self.__ssh_path.glob("*") keypaths = self.__ssh_path.glob("*")
for p in keypaths: for p in keypaths:
if query.fullmatch(p.name): if query.fullmatch(p.name):
if not Match("(known_hosts|authorized_keys|config).*", p.name):
self.append(p) self.append(p)
else: else:
continue continue
@@ -517,6 +556,7 @@ class SSHKeyCollection(Sequence):
keypaths = self.__ssh_path.glob(query) keypaths = self.__ssh_path.glob(query)
for p in keypaths: for p in keypaths:
if not Match("(known_hosts|authorized_keys|config).*", p.name):
self.append(p) self.append(p)
def reverse(self) -> None | Never: def reverse(self) -> None | Never:
@@ -525,7 +565,7 @@ class SSHKeyCollection(Sequence):
def sort(self, key: Callable = (lambda e: e), reverse: bool = False) -> None | Never: def sort(self, key: Callable = (lambda e: e), reverse: bool = False) -> None | Never:
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self) -> str:
prefix = "[(" prefix = "[("
postfix = "|]" postfix = "|]"
@@ -551,10 +591,3 @@ class UserSSH:
self.password = password self.password = password
self.fate = fate self.fate = fate
sshkey_coll = SSHKeyCollection()
sshkey_coll.pull()
print(len(sshkey_coll))
print(sshkey_coll[0:len(sshkey_coll)])
for k in sshkey_coll:
print(k)