# Source code for steps.API_1.utilities.geom_decompose

####################################################################################
#
#    STEPS - STochastic Engine for Pathway Simulation
#    Copyright (C) 2007-2023 Okinawa Institute of Science and Technology, Japan.
#    Copyright (C) 2003-2006 University of Antwerp, Belgium.
#    
#    See the file AUTHORS for details.
#    This file is part of STEPS.
#    
#    STEPS is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License version 3,
#    as published by the Free Software Foundation.
#    
#    STEPS is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#    GNU General Public License for more details.
#    
#    You should have received a copy of the GNU General Public License
#    along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#################################################################################   
###
from __future__ import print_function

from numpy import *
import math
import warnings

from steps.API_1.geom import UNKNOWN_TET
from steps.API_1.geom import INDEX_DTYPE

################################################################################

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

################################################################################

def binTetsByAxis(mesh, nbins, axis = -1):
    """
    Bin tetrahedrons of the mesh along a specific axis.
    This function is now deprecated, use linearPartition() instead.

    Each tetrahedron is assigned to one of ``nbins`` equal-width bins between
    the mesh bounding-box minimum and maximum on the chosen axis, according to
    the barycenter coordinate on that axis.

    Parameters:
        * mesh            STEPS Tetmesh object
        * nbins           Number of bins along the axis
        * axis            The partitioning axis (Option -1: longest axis, 0:x, 1:y, 2:z)

    Return:
        Tetrahedron partition list for parallel TetOpsplit solver
        (array of 0-based bin indices, one entry per tetrahedron)
    """
    warnings.warn("This function is deprecated, use linearPartition() instead.")

    bound_max = mesh.getBoundMax()
    bound_min = mesh.getBoundMin()

    selected_axis = axis
    if axis == -1:
        # pick the axis along which the bounding box is longest
        extents = array(bound_max) - array(bound_min)
        selected_axis = extents.argmax()

    # nbins + 1 edges delimit nbins equal-width bins
    bin_edges = linspace(bound_min[selected_axis], bound_max[selected_axis], nbins + 1)
    centers = [mesh.getTetBarycenter(tet)[selected_axis] for tet in range(mesh.ntets)]

    # digitize() is 1-based for values inside the edges; shift to 0-based
    belongs = digitize(centers, bin_edges) - 1

    # A coordinate lying exactly on the upper bound (or on a zero-width axis)
    # lands one past the last bin; fold it back into the valid range.
    if nbins > 0:
        belongs = clip(belongs, 0, nbins - 1)
    return belongs
################################################################################ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # ################################################################################
def _axisBin(coord, lower, width, nbins):
    """Return the 0-based bin index of coord on one axis, clamped to [0, nbins - 1]."""
    if width <= 0.0:
        # zero-extent axis: everything sits in the first bin
        return 0
    idx = int(math.floor((coord - lower) / width))
    if idx < 0:
        return 0
    if idx >= nbins:
        return nbins - 1
    return idx


def linearPartition(mesh, partition_info):
    """
    Partition the mesh based on the partition info.
    The partition_info is a list [xbin, ybin, zbin] where each element is the
    number of bins for the axis.

    The bounding box of the mesh is divided into xbin * ybin * zbin equal
    cells and every tetrahedron is assigned to the cell containing its
    barycenter. Cells are numbered with x varying fastest, then y, then z:
    ``index = xidx + xbin * yidx + xbin * ybin * zidx``.

    The cell is computed directly from the barycenter instead of scanning all
    cells, so a barycenter can no longer fall between accumulated bin edges
    and be left silently in partition 0; out-of-range coordinates are clamped
    to the nearest bin.

    Parameters:
        * mesh            STEPS Tetmesh object
        * partition_info  a list [xbin, ybin, zbin] describing the binning
                          requirement of each axis

    Return:
        Tetrahedron partition list for parallel TetOpsplit solver
    """
    assert len(partition_info) == 3

    bmax = mesh.getBoundMax()
    bmin = mesh.getBoundMin()

    nbins = [partition_info[0], partition_info[1], partition_info[2]]
    widths = [(bmax[ax] - bmin[ax]) / nbins[ax] for ax in range(3)]
    # multiplier of each axis index in the flattened cell number
    strides = [1, nbins[0], nbins[0] * nbins[1]]

    part = zeros(mesh.ntets, dtype=INDEX_DTYPE)
    for tet in range(mesh.ntets):
        baryc = mesh.getTetBarycenter(tet)
        idx = 0
        for ax in range(3):
            idx += _axisBin(baryc[ax], bmin[ax], widths[ax], nbins[ax]) * strides[ax]
        part[tet] = idx

    return list(part)
################################################################################ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # ################################################################################
def partitionTris(mesh, tet_partitions, tri_list):
    """
    Partition triangles according to partitioning information of their
    attached tetrahedrons.

    A triangle is assigned to the host of its neighboring tetrahedron. If its
    two neighboring tetrahedrons are on different hosts, the triangle takes
    the host of the first one and the second tetrahedron is moved to that
    host as well, i.e. ``tet_partitions`` is modified in place. Triangles
    without any attached tetrahedron are reported and left out of the result.

    Parameters:
        * mesh            STEPS Tetmesh object
        * tet_partitions  List of partitioning for each tetrahedron in the mesh,
                          generated by linearPartition function or third-party
                          software (may be modified, see above)
        * tri_list        List of triangles that require partitioning

    Return:
        Triangle partition list for parallel TetOpsplit solver
        (dictionary {triangle: host})

    Raises:
        Exception if, after all reassignments, a triangle and one of its
        neighboring tetrahedrons still end up on different hosts.
    """
    tri_partitions = {}

    for tri in tri_list:
        neigh_tets = mesh.getTriTetNeighb(tri)
        first_tet = neigh_tets[0]
        second_tet = neigh_tets[1]

        # no tetrahedron on either side
        if first_tet == UNKNOWN_TET and second_tet == UNKNOWN_TET:
            print("Triangle ", tri, " has no attatched tetrahedron, which is unlikely. Please check your mesh.\n")
            continue

        # boundary triangle: follow the only tetrahedron it has
        if first_tet == UNKNOWN_TET:
            tri_partitions[tri] = tet_partitions[second_tet]
            continue
        if second_tet == UNKNOWN_TET:
            tri_partitions[tri] = tet_partitions[first_tet]
            continue

        if tet_partitions[first_tet] != tet_partitions[second_tet]:
            print("Neighbor tetrahedrons of triangle ", tri, " are assigned to different hosts, try to rearrange hosts for them.\n")
            # pull the second tetrahedron over to the first one's host
            tet_partitions[second_tet] = tet_partitions[first_tet]
        tri_partitions[tri] = tet_partitions[first_tet]

    # A later reassignment can break an earlier triangle, so re-check all of them.
    for tri in tri_list:
        for neigh_tet in mesh.getTriTetNeighb(tri):
            if neigh_tet == UNKNOWN_TET:
                continue
            if tet_partitions[neigh_tet] != tri_partitions[tri]:
                raise Exception("Patch triangle %i and its compartment tet are assigned to different processes." % (tri))

    return tri_partitions
################################################################################ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # ################################################################################
def getTetPartitionTable(partitions):
    """
    Convert a [tet0_host, tet1_host, ...] partitioning list to a
    {host0:[tet0, tet1, ...], ...} table.

    Parameters:
        * partitions      Partitioning list for tetrahedrons in the format of
                          [tet0_host, tet1_host, ...]

    Return:
        A dictionary in the format of {host0:[tet0, tet1, ...], ...}
    """
    part_table = {}
    for tet, host in enumerate(partitions):
        part_table.setdefault(host, []).append(tet)
    return part_table
def getTriPartitionTable(partitions):
    """
    Convert a {tri0:tri0_host, tri1:tri1_host, ...} partitioning data to a
    {host0:[tri0, tri1, ...], ...} table.

    Parameters:
        * partitions      Partitioning dictionary for triangles in the format of
                          {tri0:tri0_host, tri1:tri1_host, ...}

    Return:
        A dictionary in the format of {host0:[tri0, tri1, ...], ...}
    """
    part_table = {}
    for tri, host in partitions.items():
        part_table.setdefault(host, []).append(tri)
    return part_table
def validatePartition(mesh, tet_partitions, tri_partitions = None):
    """
    Validate the partitioning of the mesh.

    Two checks are performed:
        * every tetrahedron is expected to have at least one neighboring
          tetrahedron in its own partition; a violation is only reported,
          not treated as an error
        * every partitioned triangle must be on the same host as each of its
          neighboring tetrahedrons

    Parameters:
        * mesh            STEPS Tetmesh object
        * tet_partitions  Partition list for tetrahedrons
        * tri_partitions  Partition list for triangles (Optional)

    Return:
        None

    Raises:
        Exception if a triangle and one of its neighboring tetrahedrons are
        assigned to different hosts.
    """
    if tri_partitions is None:
        tri_partitions = {}

    print("Validation starts.")

    # group tetrahedrons by host, preserving first-seen host order
    tet_part_table = {}
    for tet, host in enumerate(tet_partitions):
        tet_part_table.setdefault(host, []).append(tet)

    # validate if elements in each tet partition are connected
    for part in tet_part_table.values():
        members = set(part)  # constant-time membership test
        for tet in part:
            neighb_in_part = False
            for n_tet in mesh.getTetTetNeighb(tet):
                if n_tet == UNKNOWN_TET:
                    continue
                if n_tet in members:
                    neighb_in_part = True
                    break
            if not neighb_in_part:
                print("Tetrahedron %i has no neighbor in its partition. This is unusual but still acceptable." % (tet))

    # triangles must share the host of every attached tetrahedron
    for tri in tri_partitions:
        for neigh_tet in mesh.getTriTetNeighb(tri):
            if neigh_tet == UNKNOWN_TET:
                continue
            if tet_partitions[neigh_tet] != tri_partitions[tri]:
                raise Exception("Patch triangle %i and its compartment tet are assigned to different processes." % (tri))

    print("Validation completed.")
################################################################################

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

################################################################################

def printHostDistStat(tet_partitions = None, tri_partitions = None, wmvol_partitions = None):
    """
    Old name of printPartitionStat(), kept for backward compatibility.

    Prints a notice about the rename and forwards the arguments to
    printPartitionStat(). Nothing is returned.

    Parameters:
        * tet_partitions    Partition list for tetrahedrons (Optional)
        * tri_partitions    Partition list for triangles (Optional)
        * wmvol_partitions  Partition list for well-mixed volumes (Optional)
    """
    # Hand fresh containers to the callee so no default object is shared
    # between calls.
    if tet_partitions is None:
        tet_partitions = []
    if tri_partitions is None:
        tri_partitions = {}
    if wmvol_partitions is None:
        wmvol_partitions = []

    print("Warnning: This function has been renamed to printPartitionStat(tet_partitions, tri_partitions, wmvol_partitions).")
    printPartitionStat(tet_partitions, tri_partitions, wmvol_partitions)

################################################################################

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

################################################################################
def _countPerHost(hosts):
    """Return a list whose i-th entry is the number of items assigned to host i."""
    stats = []
    for host in hosts:
        if host >= len(stats):
            # grow the histogram so that index `host` exists
            stats.extend([0] * (host - len(stats) + 1))
        stats[host] += 1
    return stats


def _printDistribution(label, stats):
    """Print the per-host counts of a histogram followed by their total."""
    print(label)
    total = 0
    for count in stats:
        print(count)
        total += count
    print("")
    print("Sum: ", total)


def printPartitionStat(tet_partitions = None, tri_partitions = None, wmvol_partitions = None, mesh = None):
    """
    Print out partitioning statistics.

    The input containers are only read, never modified.

    Parameters:
        * tet_partitions    Partition list for tetrahedrons
        * tri_partitions    Partition list for triangles (Optional)
        * wmvol_partitions  Partition list for well-mixed volumes (Optional)
        * mesh              STEPS Tetmesh object (Optional)

    Return:
        (If mesh is provided)
        tet_stats, tri_stats, wm_stats, num_hosts, min_degree, max_degree, mean_degree

        [tet/tri/wm]_stats contains the number of tetrahedrons/triangles/well-mixed
        volumes in each hosting process, num_hosts provide the number of hosting
        processes, [min/max/mean]_degree provides the minimum/maximum/average
        connectivity degree of the partitioning

        If mesh is not provided, nothing is returned.
    """
    if tet_partitions is None:
        tet_partitions = []
    if tri_partitions is None:
        tri_partitions = {}
    if wmvol_partitions is None:
        wmvol_partitions = []

    tet_stats = _countPerHost(tet_partitions)
    print("Total number of assigned tets: ", len(tet_partitions))
    if tet_stats != []:
        _printDistribution("Distribution: ", tet_stats)

    tri_stats = _countPerHost(tri_partitions.values())
    print("Total number of assigned tris: ", len(tri_partitions))
    if tri_stats != []:
        _printDistribution("Distribution: ", tri_stats)

    # Count into a separate histogram; the input list must stay untouched.
    wm_stats = _countPerHost(wmvol_partitions)
    print("Total number of assigned well-mixed volumes: ", len(wmvol_partitions))
    if wm_stats != []:
        _printDistribution("WMVol Distribution: ", wm_stats)

    if mesh is not None:
        # for every partition, the set of other partitions it shares a face with
        partition_neighbors = {}
        for tet in range(mesh.ntets):
            tet_part = tet_partitions[tet]
            neighbors = partition_neighbors.setdefault(tet_part, set())
            for neighb in mesh.getTetTetNeighb(tet):
                if neighb == UNKNOWN_TET:
                    continue
                neighb_tet_part = tet_partitions[neighb]
                if neighb_tet_part != tet_part:
                    neighbors.add(neighb_tet_part)

        host_degrees = [len(neighbs) for neighbs in partition_neighbors.values()]
        min_degree = min(host_degrees)
        max_degree = max(host_degrees)
        mean_degree = mean(host_degrees)

        print("Number of partitions: ", len(tet_stats))
        print("Min Tet Partition Degree: ", min_degree)
        print("Max Tet Partition Degree: ", max_degree)
        print("Mean Tet Partition Degree: ", mean_degree)
        print("")
        return tet_stats, tri_stats, wm_stats, len(tet_stats), min_degree, max_degree, mean_degree
def isPointInCylinder(cyl_p0, cyl_p1, test_pnt, scale):
    """
    This function is deprecated, use isPointInTruncatedCone() instead.

    Emits a warning and forwards all arguments unchanged to
    isPointInTruncatedCone(), returning its result.
    """
    warnings.warn("This function is deprecated, use isPointInTruncatedCone() instead.")
    return isPointInTruncatedCone(cyl_p0, cyl_p1, test_pnt, scale)
################################################################################ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # ################################################################################
def isPointInTruncatedCone(cyl_p0, cyl_p1, test_pnt, scale):
    """
    Check if a scaled point is inside a truncated cone in 3D.
    The truncated cone is defined by an axis from the start circle cyl_p0 to
    the end circle cyl_p1.
    If cyl_p0 and cyl_p1 are from .swc morphology, and test_pnt is from a
    steps.geom.Tetmesh generated from the .swc data and imported with an
    importing scale of 1e-6 (micrometer), then the scale here is 1e-6.

    Arguments:
        * cyl_p0      coordinates and diameter of the start circle, in the form of [x, y, z, d]
        * cyl_p1      coordinates and diameter of the end circle, in the form of [x, y, z, d]
        * test_pnt    coordinate of the test point
        * scale       the scale between measurements of the truncated cone and the test point

    Return:
        -1 if point is outside the truncated cone (or the axis has practically
        zero length), distance squared from truncated cone axis if point is
        inside. The returned distance squared is never negative.

    Modified from:
        CylTest_CapsFirst of Greg James - gjames@NVIDIA.com
        http://www.flipcode.com/archives/Fast_Point-In-Cylinder_Test.shtml
    Original License:
        Free code - no warranty & no money back.  Use it all you want
    """
    # axis vector from p0 to p1, in test point units
    dx = (cyl_p1[0] - cyl_p0[0]) * scale
    dy = (cyl_p1[1] - cyl_p0[1]) * scale
    dz = (cyl_p1[2] - cyl_p0[2]) * scale
    cyl_lengthsq = dx**2 + dy**2 + dz**2

    # a degenerate axis encloses nothing (and would divide by zero below)
    if math.isclose(cyl_lengthsq, 0.0, abs_tol=1e-18):
        return -1

    # vector from p0 to the test point
    pdx0 = test_pnt[0] - cyl_p0[0] * scale
    pdy0 = test_pnt[1] - cyl_p0[1] * scale
    pdz0 = test_pnt[2] - cyl_p0[2] * scale

    # dot / cyl_lengthsq is the fractional position along the axis;
    # outside [0, 1] the point is beyond one of the end caps
    dot = pdx0 * dx + pdy0 * dy + pdz0 * dz
    if dot < 0.0 or dot > cyl_lengthsq:
        return -1

    # squared distance to the axis = |p - p0|^2 - (projection length)^2
    dsq = pdx0**2 + pdy0**2 + pdz0**2 - dot**2 / cyl_lengthsq
    if dsq < 0.0:
        # rounding for points on the axis; a squared distance cannot be negative
        dsq = 0.0

    r_p0 = cyl_p0[3] * scale / 2.0
    if math.isclose(cyl_p0[3], cyl_p1[3], rel_tol = 1e-9):
        # simplified cylinder case
        radius = r_p0
    else:
        # radius varies linearly along the axis
        r_p1 = cyl_p1[3] * scale / 2.0
        radius = r_p0 + (r_p1 - r_p0) * dot / cyl_lengthsq

    if dsq > radius**2:
        return -1
    return dsq
################################################################################ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # ################################################################################
def getCenter(p0, p1, scale):
    """
    Compute the center of p0 and p1, multiply by the scaling factor.

    Parameters:
        * p0      first point; only its first three components are used
        * p1      second point; only its first three components are used
        * scale   factor applied to the midpoint coordinates

    Return:
        The scaled midpoint as a list [x, y, z]
    """
    return [(p0[axis] + p1[axis]) * scale / 2.0 for axis in range(3)]