diff --git a/bvh.py b/bvh.py index 40fc5a2..e811bef 100644 --- a/bvh.py +++ b/bvh.py @@ -1,5 +1,8 @@ import re +LINE_SPLIT_RE = re.compile(r'[\r\n]+') +SPACE_SPLIT_RE = re.compile(r'\s+') + class BvhNode: @@ -47,23 +50,24 @@ def __init__(self, data): self.data = data self.root = BvhNode() self.frames = [] + self._joints_cache = None + self._joint_lookup_cache = None + self._joint_channel_start_cache = None + self._joint_channel_index_cache = None + self._joint_channels_cache = None self.tokenize() def tokenize(self): - first_round = [] - accumulator = '' - for char in self.data: - if char not in ('\n', '\r'): - accumulator += char - elif accumulator: - first_round.append(re.split('\\s+', accumulator.strip())) - accumulator = '' node_stack = [self.root] frame_time_found = False node = None - for item in first_round: + for line in LINE_SPLIT_RE.split(self.data): + stripped_line = line.strip() + if not stripped_line: + continue + item = SPACE_SPLIT_RE.split(stripped_line) if frame_time_found: - self.frames.append(item) + self.frames.append([float(value) for value in item]) continue key = item[0] if key == '{': @@ -73,9 +77,50 @@ def tokenize(self): else: node = BvhNode(item) node_stack[-1].add_child(node) - if item[0] == 'Frame' and item[1] == 'Time:': + if key == 'Frame' and len(item) > 1 and item[1] == 'Time:': frame_time_found = True + def _ensure_joint_cache(self): + if self._joints_cache is not None: + return + + joints = [] + joint_lookup = {} + joint_channel_start = {} + joint_channel_index = {} + joint_channels = {} + channel_start = [0] + + root_joint = next(self.root.filter('ROOT'), None) + if root_joint: + def iterate_joints(joint): + joints.append(joint) + joint_name = joint.value[1] + joint_lookup[joint_name] = joint + + channels_data = joint['CHANNELS'] + if channels_data: + channels = channels_data[1:] + for index, channel in enumerate(channels): + joint_channel_index[(joint_name, channel)] = index + joint_channels[joint_name] = channels + joint_channel_start[joint_name] = channel_start[0] + channel_start[0] += int(channels_data[0]) + else: + joint_channels[joint_name] = [] + joint_channel_start[joint_name] = channel_start[0] + + for child in joint.filter('JOINT'): + iterate_joints(child) + + iterate_joints(root_joint) + + self._joints_cache = joints + self._joint_lookup_cache = joint_lookup + self._joint_channel_start_cache = joint_channel_start + self._joint_channel_index_cache = joint_channel_index + self._joint_channels_cache = joint_channels + def search(self, *items): found_nodes = [] @@ -94,24 +139,12 @@ def check_children(node): return found_nodes def get_joints(self): - joints = [] - - def iterate_joints(joint): - joints.append(joint) - for child in joint.filter('JOINT'): - iterate_joints(child) - iterate_joints(next(self.root.filter('ROOT'))) - return joints + self._ensure_joint_cache() + return list(self._joints_cache) def get_joints_names(self): - joints = [] - - def iterate_joints(joint): - joints.append(joint.value[1]) - for child in joint.filter('JOINT'): - iterate_joints(child) - iterate_joints(next(self.root.filter('ROOT'))) - return joints + self._ensure_joint_cache() + return [joint.value[1] for joint in self._joints_cache] def joint_direct_children(self, name): joint = self.get_joint(name) @@ -121,11 +154,10 @@ def get_joint_index(self, name): return self.get_joints().index(self.get_joint(name)) def get_joint(self, name): - found = self.search('ROOT', name) - if not found: - found = self.search('JOINT', name) - if found: - return found[0] + self._ensure_joint_cache() + joint = self._joint_lookup_cache.get(name) + if joint: + return joint raise LookupError('joint not found') def joint_offset(self, name): @@ -134,31 +166,29 @@ def joint_offset(self, name): return (float(offset[0]), float(offset[1]), float(offset[2])) def joint_channels(self, name): - joint = self.get_joint(name) - return joint['CHANNELS'][1:] + self._ensure_joint_cache() + channels = self._joint_channels_cache.get(name) + if channels is not None: + return list(channels) + raise LookupError('joint not found') def get_joint_channels_index(self, joint_name): - index = 0 - for joint in self.get_joints(): - if joint.value[1] == joint_name: - return index - index += int(joint['CHANNELS'][0]) + self._ensure_joint_cache() + index = self._joint_channel_start_cache.get(joint_name) + if index is not None: + return index raise LookupError('joint not found') def get_joint_channel_index(self, joint, channel): - channels = self.joint_channels(joint) - if channel in channels: - channel_index = channels.index(channel) - else: - channel_index = -1 - return channel_index + self._ensure_joint_cache() + return self._joint_channel_index_cache.get((joint, channel), -1) def frame_joint_channel(self, frame_index, joint, channel, value=None): joint_index = self.get_joint_channels_index(joint) channel_index = self.get_joint_channel_index(joint, channel) if channel_index == -1 and value is not None: return value - return float(self.frames[frame_index][joint_index + channel_index]) + return self.frames[frame_index][joint_index + channel_index] def frame_joint_channels(self, frame_index, joint, channels, value=None): values = [] @@ -168,11 +198,7 @@ def frame_joint_channels(self, frame_index, joint, channels, value=None): if channel_index == -1 and value is not None: values.append(value) else: - values.append( - float( - self.frames[frame_index][joint_index + channel_index] - ) - ) + values.append(self.frames[frame_index][joint_index + channel_index]) return values def frames_joint_channels(self, joint, channels, value=None): @@ -185,8 +211,7 @@ def frames_joint_channels(self, joint, channels, value=None): if channel_index == -1 and value is not None: values.append(value) else: - values.append( - float(frame[joint_index + channel_index])) + values.append(frame[joint_index + channel_index]) all_frames.append(values) return all_frames