Commit 9524b285 authored by sfritschi's avatar sfritschi
Browse files

Refactoring

parent 668e17d7
......@@ -12,7 +12,7 @@ def main():
nthreads = int(sys.argv[1])
print("Using %d thread(s)" % nthreads)
basenet = netflow.load_network_from('../netflow/network/network.h5', isGenerator=True)
basenet = netflow.load_network_from('../netflow/network/network.h5', isBaseNet=True)
print("Network statistics:")
pore_throat_counts = [len(pore.throats) for pore in basenet.pores]
print("Max. number of throats per pore: %d" % max(pore_throat_counts))
......
......@@ -62,7 +62,7 @@ class Network:
throatR = []
def __init__(self, lb: List[float] = [], ub: List[float] = [], \
pores: List[Pore] = None, throats: List[Throat] = None, \
Lmax: float = 0.0, label: str = None):
Lmax: float = 0.0, label: str = None, isBaseNet: bool = False):
# lower/upper bounds for network volume
self.lb = lb.copy(); self.ub = ub.copy()
# sets with pores and throats
......@@ -77,6 +77,7 @@ class Network:
self.label = datetime.today().isoformat(sep=' ',timespec='seconds')
else:
self.label = label
self.isBaseNet = isBaseNet
# Lengths of network domain
self.L = [ub-lb for lb,ub in zip(self.lb,self.ub)]
def __repr__(self):
......@@ -663,6 +664,9 @@ def generate_dendrogram(basenet: Network, targetsize: List[int], \
# check valid number of threads
if (nthreads < 1):
raise ValueError('Number of threads must be >= 1!')
# check network provided was initialized properly
if (not basenet.isBaseNet):
raise ValueError('Provided network was not initialized as base network!')
# pores, distributed based on dendrogram of basenet
if (not mute): print("distributing pores...")
......@@ -1159,14 +1163,14 @@ def save_network_to(filename: str, network: Network):
data=wrk)
def load_network_from(filename: str, isGenerator: bool = False) -> Network:
def load_network_from(filename: str, isBaseNet: bool) -> Network:
"""Load pore network from hdf5 file."""
import h5py
import numpy as np
f = h5py.File(filename, 'r')
# global network properties
network = Network(label=f['network_label'][0].decode('utf-8'),
lb=list(f['lb']), ub=list(f['ub']), Lmax=f['Lmax'][0])
lb=list(f['lb']), ub=list(f['ub']), Lmax=f['Lmax'][0], isBaseNet=isBaseNet)
# pores
pid = np.array(f['pores/id'])
......@@ -1188,7 +1192,7 @@ def load_network_from(filename: str, isGenerator: bool = False) -> Network:
# pick appropriate connector method if generation algorithm is used
connector_method = network.connect_pores
if isGenerator:
if isBaseNet:
connector_method = network.connect_pores_generator
# throats
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment