diff --git a/axelrod/interaction_utils.py b/axelrod/interaction_utils.py index a9449f450..dafefc0c2 100644 --- a/axelrod/interaction_utils.py +++ b/axelrod/interaction_utils.py @@ -9,6 +9,7 @@ interactions. """ import csv +from collections import Counter from .game import Game from axelrod import Actions @@ -87,6 +88,57 @@ def compute_normalised_cooperation(interactions): return normalised_cooperation +def compute_state_distribution(interactions): + """ + Returns the count of each state for a set of interactions. + + Parameters + ---------- + interactions : list of tuples + A list containing the interactions of the match as shown at the top of + this file. + + Returns + ---------- + Counter(interactions) : Counter Object + Dictionary where the keys are the states and the values are the number + of times that state occurs. + """ + if not interactions: + return None + return Counter(interactions) + + +def compute_normalised_state_distribution(interactions): + """ + Returns the normalized count of each state for a set of interactions. + + Parameters + ---------- + interactions : list of tuples + A list containing the interactions of the match as shown at the top of + this file. + + Returns + ---------- + normalized_count : Counter Object + Dictionary where the keys are the states and the values are a normalized + count of the number of times that state occurs. + """ + if not interactions: + return None + + interactions_count = Counter(interactions) + total = sum(interactions_count.values(), 0.0) + # By starting the sum with 0.0 we make sure total is a floating point value, + # avoiding the Python 2 floor division behaviour of / with integer operands + # (Stack Overflow) + + normalized_count = Counter({key: value / total for key, value in + interactions_count.items()}) + return normalized_count + + def sparkline(actions, c_symbol=u'█', d_symbol=u' '): return u''.join([ c_symbol if play == 'C' else d_symbol for play in actions]) diff --git a/axelrod/match.py b/axelrod/match.py index 8b70b928e..1fe5f4e7b 100644 --- a/axelrod/match.py +++ b/axelrod/match.py @@ -159,6 +159,18 @@ def normalised_cooperation(self): """Returns the count of cooperations by each player per turn""" return iu.compute_normalised_cooperation(self.result) + def state_distribution(self): + """ + Returns the count of each state for a set of interactions. + """ + return iu.compute_state_distribution(self.result) + + def normalised_state_distribution(self): + """ + Returns the normalized count of each state for a set of interactions. + """ + return iu.compute_normalised_state_distribution(self.result) + def sparklines(self, c_symbol=u'█', d_symbol=u' '): return iu.compute_sparklines(self.result, c_symbol, d_symbol) diff --git a/axelrod/tests/unit/test_interaction_utils.py b/axelrod/tests/unit/test_interaction_utils.py index 11aab59f2..82cba14d3 100644 --- a/axelrod/tests/unit/test_interaction_utils.py +++ b/axelrod/tests/unit/test_interaction_utils.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import unittest import tempfile +from collections import Counter import axelrod import axelrod.interaction_utils as iu @@ -17,6 +18,14 @@ class TestMatch(unittest.TestCase): winners = [False, 0, 1, None] cooperations = [(1, 1), (0, 2), (2, 1), None] normalised_cooperations = [(.5, .5), (0, 1), (1, .5), None] + state_distribution = [Counter({('C', 'D'): 1, ('D', 'C'): 1}), + Counter({('D', 'C'): 2}), + Counter({('C', 'C'): 1, ('C', 'D'): 1}), + None] + normalised_state_distribution = [Counter({('C', 'D'): 0.5, ('D', 'C'): 0.5}), + Counter({('D', 'C'): 1.0}), + Counter({('C', 'C'): 0.5, ('C', 'D'): 0.5}), + None] sparklines = [ u'█ \n █', u' \n██', u'██\n█ ', None ] @@ -46,6 +55,14 @@ def test_compute_normalised_cooperations(self): for inter, coop in zip(self.interactions, self.normalised_cooperations): self.assertEqual(coop, iu.compute_normalised_cooperation(inter)) + def test_compute_state_distribution(self): + for inter, dist in zip(self.interactions, self.state_distribution): + self.assertEqual(dist, iu.compute_state_distribution(inter)) + + def test_compute_normalised_state_distribution(self): + for inter, dist in zip(self.interactions, self.normalised_state_distribution): + self.assertEqual(dist, iu.compute_normalised_state_distribution(inter)) + def test_compute_sparklines(self): for inter, spark in zip(self.interactions, self.sparklines): self.assertEqual(spark, iu.compute_sparklines(inter))