diff --git a/pyasic/network/__init__.py b/pyasic/network/__init__.py index fa43a783..85db07b8 100644 --- a/pyasic/network/__init__.py +++ b/pyasic/network/__init__.py @@ -21,80 +21,91 @@ from typing import AsyncIterator, List, Union from pyasic import settings from pyasic.miners.miner_factory import AnyMiner, miner_factory -from pyasic.network.net_range import MinerNetworkRange class MinerNetwork: """A class to handle a network containing miners. Handles scanning and gets miners via [`MinerFactory`][pyasic.miners.miner_factory.MinerFactory]. Parameters: - ip_addr: ### An IP address, range of IP addresses, or a list of IPs - * Takes a single IP address as an `ipadddress.ipaddress()` or a string - * Takes a string formatted as: - ```f"{ip_range_1_start}-{ip_range_1_end}, {ip_address_1}, {ip_range_2_start}-{ip_range_2_end}, {ip_address_2}..."``` - * Also takes a list of strings or `ipaddress.ipaddress` formatted as: - ```[{ip_address_1}, {ip_address_2}, {ip_address_3}, ...]``` - mask: A subnet mask to use when constructing the network. Only used if `ip_addr` is a single IP. - Defaults to /24 (255.255.255.0 or 0.0.0.255) + hosts: A list of `ipaddress.IPv4Address` to be used when scanning. """ - def __init__( - self, - ip_addr: Union[str, List[str], None] = None, - mask: Union[str, int, None] = None, - ) -> None: - self.network = None - self.ip_addr = ip_addr - self.connected_miners = {} - if isinstance(mask, str): - if mask.startswith("/"): - mask = mask.replace("/", "") - self.mask = mask - self.network = self.get_network() + def __init__(self, hosts: List[ipaddress.IPv4Address]): + self.hosts = hosts def __len__(self): - return len([item for item in self.get_network().hosts()]) + return len(self.hosts) - def __repr__(self): - return str(self.network) + @classmethod + def from_list(cls, addresses: list): + """Parse a list of address constructors into a MinerNetwork. - def hosts(self): - for x in self.network.hosts(): - yield x + Parameters: + addresses: A list of address constructors, such as `["10.1-2.1.1-50", "10.4.1-2.1-50"]`. + """ + hosts = [] + for address in addresses: + hosts = [*hosts, *cls.from_address(address)] + return sorted(list(set(hosts))) - def get_network(self) -> ipaddress.ip_network: - """Get the network using the information passed to the MinerNetwork or from cache. + @classmethod + def from_address(cls, address: str): + """Parse an address constructor into a MinerNetwork. - Returns: - The proper network to be able to scan. + Parameters: + address: An address constructor, such as `"10.1-2.1.1-50"`. + """ + octets = address.split(".") + if len(octets) > 4: + raise ValueError("Too many octets in IP constructor.") + if len(octets) < 4: + raise ValueError("Too few octets in IP constructor.") + return cls.from_octets(*octets) + + @classmethod + def from_octets(cls, oct_1: str, oct_2: str, oct_3: str, oct_4: str): + """Parse 4 octet constructors into a MinerNetwork. + + Parameters: + oct_1: An octet constructor, such as `"10"`. + oct_2: An octet constructor, such as `"1-2"`. + oct_3: An octet constructor, such as `"1"`. + oct_4: An octet constructor, such as `"1-50"`. """ - # if we have a network cached already, use that - if self.network is not None: - return self.network + hosts = [] - # if there is no IP address passed, default to 192.168.1.0 - if not self.ip_addr: - self.ip_addr = "192.168.1.0" - if isinstance(self.ip_addr, list): - self.network = MinerNetworkRange(",".join(self.ip_addr)) - elif "-" in self.ip_addr: - self.network = MinerNetworkRange(self.ip_addr) - else: - # if there is no subnet mask passed, default to /24 - if not self.mask: - subnet_mask = "24" - # if we do have a mask passed, use that - else: - subnet_mask = str(self.mask) + oct_1_val_start, oct_1_start, oct_1_end = compute_oct_range(oct_1) + for oct_1_idx in range((abs(oct_1_end - oct_1_start)) + 1): + oct_1_val = str(oct_1_idx + oct_1_start) - # save the network and return it - self.network = ipaddress.ip_network( - f"{self.ip_addr}/{subnet_mask}", strict=False - ) + oct_2_val_start, oct_2_start, oct_2_end = compute_oct_range(oct_2) + for oct_2_idx in range((abs(oct_2_end - oct_2_start)) + 1): + oct_2_val = str(oct_2_idx + oct_2_start) - logging.debug(f"{self} - (Get Network) - Found network") - return self.network + oct_3_val_start, oct_3_start, oct_3_end = compute_oct_range(oct_3) + for oct_3_idx in range((abs(oct_3_end - oct_3_start)) + 1): + oct_3_val = str(oct_3_idx + oct_3_start) + + oct_4_val_start, oct_4_start, oct_4_end = compute_oct_range(oct_4) + for oct_4_idx in range((abs(oct_4_end - oct_4_start)) + 1): + oct_4_val = str(oct_4_idx + oct_4_start) + + hosts.append( + ipaddress.ip_address( + ".".join([oct_1_val, oct_2_val, oct_3_val, oct_4_val]) + ) + ) + return sorted(hosts) + + @classmethod + def from_subnet(cls, subnet: str): + """Parse a subnet into a MinerNetwork. + + Parameters: + subnet: A subnet string, such as `"10.0.0.1/24"`. + """ + return list(ipaddress.ip_network(subnet, strict=False).hosts()) async def scan_network_for_miners(self) -> List[AnyMiner]: """Scan the network for miners, and return found miners as a list. @@ -102,8 +113,6 @@ class MinerNetwork: Returns: A list of found miners. """ - # get the network - local_network = self.get_network() logging.debug(f"{self} - (Scan Network For Miners) - Scanning") # clear cached miners @@ -111,7 +120,7 @@ class MinerNetwork: limit = asyncio.Semaphore(settings.get("network_scan_threads", 300)) miners = await asyncio.gather( - *[self.ping_and_get_miner(host, limit) for host in local_network.hosts()] + *[self.ping_and_get_miner(host, limit) for host in self.hosts] ) # remove all None from the miner list @@ -133,15 +142,12 @@ class MinerNetwork: # get the current event loop loop = asyncio.get_event_loop() - # get the network - local_network = self.get_network() - # create a list of scan tasks limit = asyncio.Semaphore(settings.get("network_scan_threads", 300)) miners = asyncio.as_completed( [ loop.create_task(self.ping_and_get_miner(host, limit)) - for host in local_network.hosts() + for host in self.hosts ] ) for miner in miners: @@ -245,3 +251,19 @@ async def ping_and_get_miner( logging.warning(f"{str(ip)}: Ping And Get Miner Exception: {e}") raise ConnectionRefusedError return + + +def compute_oct_range(octet: str) -> tuple: + octet_split = octet.split("-") + octet_start = int(octet_split[0]) + octet_end = None + try: + octet_end = int(octet_split[1]) + except IndexError: + pass + if octet_end is None: + octet_end = int(octet_start) + 1 + + octet_val_start = min([octet_start, octet_end]) + + return octet_val_start, octet_start, octet_end diff --git a/pyasic/network/net_range.py b/pyasic/network/net_range.py deleted file mode 100644 index a487d018..00000000 --- a/pyasic/network/net_range.py +++ /dev/null @@ -1,56 +0,0 @@ -# ------------------------------------------------------------------------------ -# Copyright 2022 Upstream Data Inc - -# - -# Licensed under the Apache License, Version 2.0 (the "License"); - -# you may not use this file except in compliance with the License. - -# You may obtain a copy of the License at - -# - -# http://www.apache.org/licenses/LICENSE-2.0 - -# - -# Unless required by applicable law or agreed to in writing, software - -# distributed under the License is distributed on an "AS IS" BASIS, - -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - -# See the License for the specific language governing permissions and - -# limitations under the License. - -# ------------------------------------------------------------------------------ - -import ipaddress -from typing import Union - - -class MinerNetworkRange: - """A MinerNetwork that takes a range of IP addresses. - - Parameters: - ip_range: ## A range of IP addresses to put in the network, or a list of IPs - * Takes a string formatted as: - ```f"{ip_range_1_start}-{ip_range_1_end}, {ip_address_1}, {ip_range_2_start}-{ip_range_2_end}, {ip_address_2}..."``` - * Also takes a list of strings or `ipaddress.ipaddress` formatted as: - ```[{ip_address_1}, {ip_address_2}, {ip_address_3}, ...]``` - """ - - def __init__(self, ip_range: Union[str, list]): - self.host_ips = [] - if isinstance(ip_range, str): - ip_ranges = ip_range.replace(" ", "").split(",") - for item in ip_ranges: - if "-" in item: - start, end = item.split("-") - start_ip = ipaddress.ip_address(start) - end_ip = ipaddress.ip_address(end) - networks = ipaddress.summarize_address_range(start_ip, end_ip) - for network in networks: - self.host_ips.append(network.network_address) - for host in network.hosts(): - if host not in self.host_ips: - self.host_ips.append(host) - if network.broadcast_address not in self.host_ips: - self.host_ips.append(network.broadcast_address) - else: - self.host_ips.append(ipaddress.ip_address(item)) - elif isinstance(ip_range, list): - self.host_ips = [ipaddress.ip_address(ip_str) for ip_str in ip_range] - - def hosts(self): - for x in self.host_ips: - yield x