Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 100 additions & 72 deletions gatetools/merge_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
# -----------------------------------------------------------------------------


import gatetools as gt
import logging

import numpy as np
import tqdm
import logging
logger=logging.getLogger(__name__)
import uproot

import gatetools as gt

logger = logging.getLogger(__name__)


def unicity(root_keys):
"""
Expand All @@ -38,34 +43,30 @@ def unicity(root_keys):
name = name[0]
if not name in root_array:
root_array.append(name)
return(root_array)
return root_array


def merge_root(rootfiles, outputfile, incrementRunId=False):
"""
Merge root files in output files
"""
try:
import uproot
except:
print("uproot4 is mandatory to merge root file. Please, do:")
print("pip install uproot")

uproot.default_library = "np"

out = uproot.recreate(outputfile)

#Previous ID values to be able to increment runIn or EventId
# Previous ID values to be able to increment runIn or EventId
previousId = {}

#create the dict reading all input root files
trees = {} #TTree with TBranch
hists = {} #Directory with THist
pbar = tqdm.tqdm(total = len(rootfiles))
# create the dict reading all input root files
trees = {} # TTree with TBranch
hists = {} # Directory with THist
pbar = tqdm.tqdm(total=len(rootfiles))
for file in rootfiles:
root = uproot.open(file)
root_keys = unicity(root.keys())
for tree in root_keys:
if hasattr(root[tree], 'keys'):
if hasattr(root[tree], "keys"):
if not tree in trees:
trees[tree] = {}
trees[tree]["rootDictType"] = {}
Expand All @@ -75,67 +76,95 @@ def merge_root(rootfiles, outputfile, incrementRunId=False):
hists[tree]["rootDictValue"] = {}
previousId[tree] = {}
for branch in root[tree].keys():
if isinstance(root[tree],uproot.reading.ReadOnlyDirectory):
if isinstance(root[tree], uproot.reading.ReadOnlyDirectory):
array = root[tree][branch].values()
if len(array) > 0:
branchName = tree + "/" + branch
if type(array[0]) is type('c'):
if type(array[0]) is type("c"):
array = np.array([0 for xi in array])
if not branchName in hists[tree]["rootDictType"]:
hists[tree]["rootDictType"][branchName] = root[tree][branch].to_numpy()
hists[tree]["rootDictValue"][branchName] = np.zeros(array.shape)
hists[tree]["rootDictType"][branchName] = root[tree][
branch
].to_numpy()
hists[tree]["rootDictValue"][branchName] = np.zeros(
array.shape
)
hists[tree]["rootDictValue"][branchName] += array
else:
array = root[tree][branch].array(library="np")
if len(array) > 0 and not (type(array[0]) is type(np.ndarray(2,))):
if type(array[0]) is type('c'):
if len(array) > 0 and (
type(array[0])
is not type(
np.ndarray(
2,
)
)
):
if type(array[0]) is type("c"):
array = np.array([0 for xi in array])
if not branch in trees[tree]["rootDictType"]:
if branch not in trees[tree]["rootDictType"]:
trees[tree]["rootDictType"][branch] = type(array[0])
trees[tree]["rootDictValue"][branch] = np.array([])
if (not incrementRunId and branch.startswith('eventID')) or (incrementRunId and branch.startswith('runID')):
if not branch in previousId[tree]:
if (
not incrementRunId and branch.startswith("eventID")
) or (incrementRunId and branch.startswith("runID")):
if branch not in previousId[tree]:
previousId[tree][branch] = 0
array += previousId[tree][branch]
previousId[tree][branch] = max(array) +1
trees[tree]["rootDictValue"][branch] = np.append(trees[tree]["rootDictValue"][branch], array)
previousId[tree][branch] = max(array) + 1
trees[tree]["rootDictValue"][branch] = np.append(
trees[tree]["rootDictValue"][branch], array
)
pbar.update(1)
pbar.close()

#Set the dict in the output root file
# Set the dict in the output root file
for tree in trees:
if not trees[tree]["rootDictValue"] == {} or not trees[tree]["rootDictType"] == {}:
#out.mktree(tree, trees[tree]["rootDictType"])
out[tree] = trees[tree]["rootDictValue"]
if (
not trees[tree]["rootDictValue"] == {}
or not trees[tree]["rootDictType"] == {}
):
out.mktree(tree, trees[tree]["rootDictType"])
out[tree].extend(trees[tree]["rootDictValue"])
for hist in hists:
if not hists[hist]["rootDictValue"] == {} or not hists[hist]["rootDictType"] == {}:
if (
not hists[hist]["rootDictValue"] == {}
or not hists[hist]["rootDictType"] == {}
):
for branch in hists[hist]["rootDictValue"]:
for i in range(len(hists[hist]["rootDictValue"][branch])):
hists[hist]["rootDictType"][branch][0][i] = hists[hist]["rootDictValue"][branch][i]
out[branch[:-2]] = hists[hist]["rootDictType"][branch]
hists[hist]["rootDictType"][branch][0][i] = hists[hist][
"rootDictValue"
][branch][i]
out.mktree(branch[:-2], hists[hist]["rootDictType"][branch])
out[branch[:-2]].extend(hists[hist]["rootDictType"][branch])


#####################################################################################
import unittest
import tempfile
import wget
import os
import shutil
import tempfile
import unittest

import numpy as np
import uproot
import wget

from .logging_conf import LoggedTestCase


class Test_MergeRoot(LoggedTestCase):
def test_merge_root_phsp(self):
try:
import uproot
except:
print("uproot4 is mandatory to merge root file. Please, do:")
print("pip install uproot")

logger.info('Test_MergeRoot test_merge_root_phsp')
logger.info("Test_MergeRoot test_merge_root_phsp")
tmpdirpath = tempfile.mkdtemp()
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false", out=tmpdirpath, bar=None)
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
filenameRoot = wget.download(
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false",
out=tmpdirpath,
bar=None,
)
gt.merge_root(
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
)
input = uproot.open(filenameRoot)
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
self.assertTrue(output.keys() == input.keys())
Expand All @@ -144,57 +173,56 @@ def test_merge_root_phsp(self):
self.assertTrue(outputTree.keys() == inputTree.keys())
inputBranch = inputTree[inputTree.keys()[1]].array(library="np")
outputBranch = outputTree[outputTree.keys()[1]].array(library="np")
self.assertTrue(2*len(inputBranch) == len(outputBranch))
self.assertTrue(2 * len(inputBranch) == len(outputBranch))
shutil.rmtree(tmpdirpath)

def test_merge_root_pet_incrementEvent(self):
try:
import uproot
except:
print("uproot4 is mandatory to merge root file. Please, do:")
print("pip install uproot")

logger.info('Test_MergeRoot test_merge_root_pet')
logger.info("Test_MergeRoot test_merge_root_pet")
tmpdirpath = tempfile.mkdtemp()
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
filenameRoot = wget.download(
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
out=tmpdirpath,
bar=None,
)
gt.merge_root(
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
)
input = uproot.open(filenameRoot)
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
inputTree = input[input.keys()[0]]
outputTree = output[output.keys()[0]]
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
self.assertTrue(max(inputRunBranch) == max(outputRunBranch))
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
self.assertTrue(2*max(inputEventBranch)+1 == max(outputEventBranch))
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
self.assertTrue(2 * max(inputEventBranch) + 1 == max(outputEventBranch))
self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
shutil.rmtree(tmpdirpath)

def test_merge_root_pet_incrementRun(self):
try:
import uproot
except:
print("uproot4 is mandatory to merge root file. Please, do:")
print("pip install uproot")

logger.info('Test_MergeRoot test_merge_root_pet')
logger.info("Test_MergeRoot test_merge_root_pet")
tmpdirpath = tempfile.mkdtemp()
print(tmpdirpath)
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True)
filenameRoot = wget.download(
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
out=tmpdirpath,
bar=None,
)
gt.merge_root(
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True
)
input = uproot.open(filenameRoot)
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
inputTree = input[input.keys()[0]]
outputTree = output[output.keys()[0]]
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
self.assertTrue(2*max(inputRunBranch)+1 == max(outputRunBranch))
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
self.assertTrue(2 * max(inputRunBranch) + 1 == max(outputRunBranch))
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
self.assertTrue(max(inputEventBranch) == max(outputEventBranch))
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
#shutil.rmtree(tmpdirpath)

self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
# shutil.rmtree(tmpdirpath)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,3 @@ gt_digi_mac_converter = "gatetools.bin.gt_digi_mac_converter:convert_macro"
gate_split_and_run = "gatetools.clustertools.gate_split_and_run:runJobs"
opengate_run = "gatetools.clustertools.opengate_run:runJobs_click"
computeElapsedTime = "gatetools.clustertools.computeElapsedTime:computeElapsedTime"

Loading