feature: added magic method for containment test of 'SSHKeyCollection' and magic method for setting the status of 'SSHKey'

This commit is contained in:
2026-01-20 12:14:53 -05:00
parent 67e645a3bd
commit 28bd322c87

View File

@@ -1,4 +1,5 @@
from re import Pattern as RegEx
from re import fullmatch as Match
from pathlib import Path, PurePath
from custtypes import ExecutedPath, IdlePath, VirtualPrivateServers, AnsibleScopes
from enum import Enum
@@ -214,11 +215,34 @@ class SSHKey:
return result
def update_status(self) -> None:
if isinstance(self.__value, tuple):
privkey_present = False
pubkey_present = False
for p in self.__value:
if "-----BEGIN OPENSSH PRIVATE KEY-----" in p.read_text():
privkey_present = True
else:
pubkey_present = True
if pubkey_present and privkey_present:
self.category = SSHKeyType.dual.name
elif pubkey_present or privkey_present:
if pubkey_present:
self.category = SSHKeyType.pubkey.name
if privkey_present:
self.category = SSHKeyType.privkey.name
elif isinstance(self.__value, ExecutedPath):
if "-----BEGIN OPENSSH PRIVATE KEY-----" in self.__value.read_text():
self.category = SSHKeyType.privkey.name
else:
self.category = SSHKeyType.pubkey.name
@property
def status(self) -> str | Never:
# @TODO analyze 'read' return value of this class's instance to check whether it is a private key,
# unless '__value' of this class's instance is a tuple, in which case it is a dual key
raise NotImplementedError
def status(self) -> str:
self.update_status()
return self.category
def reverse(self) -> Never:
if isinstance(self.__value, tuple):
@@ -378,6 +402,7 @@ class SSHKeyCollection(Sequence):
# print("branch1")
ssh_key._SSHKey__idx = 0
# print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__indices = range(ssh_key._SSHKey__idx + 1)
self.__first = ssh_key
@@ -388,6 +413,7 @@ class SSHKeyCollection(Sequence):
# print("branch2.1")
ssh_key._SSHKey__idx = self.__last._SSHKey__idx + 1
# print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__last._SSHKey__next = ssh_key
self.__last._SSHKey__next._SSHKey__prev = self.__last
self.__last = next(self.__last)
@@ -395,6 +421,7 @@ class SSHKeyCollection(Sequence):
# print("branch2.2")
ssh_key._SSHKey__idx = self.__first._SSHKey__idx + 1
# print(ssh_key._SSHKey__idx)
ssh_key.update_status()
self.__first._SSHKey__next = ssh_key
self.__first._SSHKey__next._SSHKey__prev = self.__first
self.__last = self.__first._SSHKey__next
@@ -483,10 +510,21 @@ class SSHKeyCollection(Sequence):
return self.__last
def __contains__(self) -> bool | Never:
raise NotImplementedError
def __contains__(self, value: ExecutedPath | str) -> bool:
if isinstance(value, ExecutedPath):
value = str(value)
def __missing__(self) -> Never:
is_contained = False
while self.__current is not None:
if str(self.__current._SSHKey__value) == value:
is_contained = True
break
self.__current = next(self.__current)
return is_contained
def __missing__(self, value: ExecutedPath | str) -> Never:
raise NotImplementedError
def __next__(self):
@@ -504,12 +542,13 @@ class SSHKeyCollection(Sequence):
raise NotImplementedError
# @TODO make sure to implement below method
def pull(self, query: RegEx | str = "*") -> Never:
def pull(self, query: RegEx | str = "*") -> None:
if isinstance(query, RegEx):
keypaths = self.__ssh_path.glob("*")
for p in keypaths:
if query.fullmatch(p.name):
if not Match("(known_hosts|authorized_keys|config).*", p.name):
self.append(p)
else:
continue
@@ -517,6 +556,7 @@ class SSHKeyCollection(Sequence):
keypaths = self.__ssh_path.glob(query)
for p in keypaths:
if not Match("(known_hosts|authorized_keys|config).*", p.name):
self.append(p)
def reverse(self) -> None | Never:
@@ -525,7 +565,7 @@ class SSHKeyCollection(Sequence):
def sort(self, key: Callable = (lambda e: e), reverse: bool = False) -> None | Never:
raise NotImplementedError
def __str__(self):
def __str__(self) -> str:
prefix = "[("
postfix = "|]"
@@ -551,10 +591,3 @@ class UserSSH:
self.password = password
self.fate = fate
sshkey_coll = SSHKeyCollection()
sshkey_coll.pull()
print(len(sshkey_coll))
print(sshkey_coll[0:len(sshkey_coll)])
for k in sshkey_coll:
print(k)