
Test-Driven Python Development

By: Siddharta Govindaraj

Overview of this book

This book starts with a look at the test-driven development process, and how it is different from the traditional way of writing code. All the concepts are presented in the context of a real application that is developed in a step-by-step manner over the course of the book. While exploring the common types of smelly code, we will go back into our example project and clean up the smells that we find. Additionally, we will use mocking to implement the parts of our example project that depend on other systems. Towards the end of the book, we'll take a look at the most common patterns and anti-patterns associated with test-driven development, including integration of test results into the development process.

Code Smells and Refactoring


This exercise asks us to refactor the Stock class and extract all the moving average related calculations into a new class.

The following is the code that we start with:

def get_crossover_signal(self, on_date):
    NUM_DAYS = self.LONG_TERM_TIMESPAN + 1
    closing_price_list = \
        self.history.get_closing_price_list(on_date, NUM_DAYS)

    if len(closing_price_list) < NUM_DAYS:
        return StockSignal.neutral

    long_term_series = \
        closing_price_list[-self.LONG_TERM_TIMESPAN:]
    prev_long_term_series = \
        closing_price_list[-self.LONG_TERM_TIMESPAN-1:-1]
    short_term_series = \
        closing_price_list[-self.SHORT_TERM_TIMESPAN:]
    prev_short_term_series = \
        closing_price_list[-self.SHORT_TERM_TIMESPAN-1:-1]

    long_term_ma = sum([update.value
                        for update in long_term_series])\
                    /self.LONG_TERM_TIMESPAN
    prev_long_term_ma = sum([update.value
                             for update in prev_long_term_series])\
                         /self.LONG_TERM_TIMESPAN
    short_term_ma = sum([update.value
                         for update in short_term_series])\
                    /self.SHORT_TERM_TIMESPAN
    prev_short_term_ma = sum([update.value
                              for update in prev_short_term_series])\
                         /self.SHORT_TERM_TIMESPAN

    if self._is_crossover_below_to_above(prev_short_term_ma,
                                         prev_long_term_ma,
                                         short_term_ma,
                                         long_term_ma):
        return StockSignal.buy

    if self._is_crossover_below_to_above(prev_long_term_ma,
                                         prev_short_term_ma,
                                         long_term_ma,
                                         short_term_ma):
        return StockSignal.sell

    return StockSignal.neutral

As we can see, there are a number of calculations relating to identifying the moving average window and then calculating the moving average value. These calculations really deserve to be in their own class.
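The four moving average calculations above share one pattern: take the last N closing prices, optionally shifted back by one entry, and divide their sum by N. A minimal sketch of that shared calculation on a plain list of floats makes the duplication visible. The helper name moving_average_of, its offset parameter, and its use of ValueError are hypothetical and not part of the book's code.

```python
def moving_average_of(prices, timespan, offset=0):
    """Average of `timespan` prices ending `offset` entries before the end.

    Hypothetical helper: offset=0 mirrors slices like closing_price_list[-N:],
    and offset=1 mirrors closing_price_list[-N-1:-1] (the previous window).
    """
    end = len(prices) - offset
    start = end - timespan
    if start < 0:
        # Checked before slicing, because a negative start would wrap around.
        raise ValueError("Not enough data")
    return sum(prices[start:end]) / timespan


prices = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
print(moving_average_of(prices, 3))            # 14.0, from 13, 14, 15
print(moving_average_of(prices, 3, offset=1))  # 13.0, from 12, 13, 14
```

The same formula appears four times in get_crossover_signal with only the timespan and offset changing, which is a strong hint that it belongs in one place.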

To start with, we create an empty MovingAverage class as follows:

class MovingAverage:
    pass

Now we need to make a design decision on how we want this class to be used. Let us decide that the class should take an underlying timeseries and should be able to compute the moving average at any point based on that timeseries. With this design, the class needs to take the timeseries and the duration of the moving average as parameters, as shown in the following:

def __init__(self, series, timespan):
    self.series = series
    self.timespan = timespan

We can now extract the moving average calculation into this class as follows:

class MovingAverage:
    def __init__(self, series, timespan):
        self.series = series
        self.timespan = timespan

    def value_on(self, end_date):
        moving_average_range = self.series.get_closing_price_list(
                                   end_date, self.timespan)
        if len(moving_average_range) < self.timespan:
            raise NotEnoughDataException("Not enough data")
        price_list = [item.value for item in moving_average_range]
        return sum(price_list)/len(price_list)

This is the same moving average calculation as in Stock.get_crossover_signal. The only notable difference is that an exception is raised if there is not enough data to perform the calculation. Let us define this exception in the timeseries.py file as follows:

class NotEnoughDataException(Exception):
    pass
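To try MovingAverage on its own, a small stand-in for the series is enough. The sketch below assumes a hypothetical SimpleTimeSeries whose get_closing_price_list returns only the updates dated within the num_days days ending on the given date; this is a simplification, not the book's TimeSeries. PriceUpdate is likewise a hypothetical namedtuple, while MovingAverage and NotEnoughDataException are copied from the text.

```python
from collections import namedtuple
from datetime import datetime, timedelta

# Hypothetical price record; MovingAverage only reads `.value`.
PriceUpdate = namedtuple("PriceUpdate", ["timestamp", "value"])


class NotEnoughDataException(Exception):
    pass


class SimpleTimeSeries:
    """Hypothetical stand-in for TimeSeries, assuming one price per day."""

    def __init__(self, updates):
        self.updates = sorted(updates, key=lambda u: u.timestamp)

    def get_closing_price_list(self, on_date, num_days):
        # Updates dated within the num_days days ending on on_date.
        first_day = (on_date - timedelta(num_days - 1)).date()
        return [u for u in self.updates
                if first_day <= u.timestamp.date() <= on_date.date()]


class MovingAverage:
    def __init__(self, series, timespan):
        self.series = series
        self.timespan = timespan

    def value_on(self, end_date):
        moving_average_range = self.series.get_closing_price_list(
            end_date, self.timespan)
        if len(moving_average_range) < self.timespan:
            raise NotEnoughDataException("Not enough data")
        price_list = [item.value for item in moving_average_range]
        return sum(price_list) / len(price_list)


first = datetime(2014, 2, 1)
series = SimpleTimeSeries(
    [PriceUpdate(first + timedelta(i), price)
     for i, price in enumerate([10.0, 11.0, 12.0, 13.0])])
three_day_ma = MovingAverage(series, 3)
print(three_day_ma.value_on(datetime(2014, 2, 4)))  # (11 + 12 + 13) / 3 = 12.0
```

Asking for the 3-day average on 2 February raises NotEnoughDataException, since only two prices exist up to that date. Because MovingAverage only calls get_closing_price_list and reads .value, any series object with that shape works.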

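We can now use this class in Stock.get_crossover_signal as follows (the code computes the previous date with timedelta, so stock.py must import timedelta from the datetime module):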

def get_crossover_signal(self, on_date):
    prev_date = on_date - timedelta(1)
    long_term_ma = \
        MovingAverage(self.history, self.LONG_TERM_TIMESPAN)
    short_term_ma = \
        MovingAverage(self.history, self.SHORT_TERM_TIMESPAN)

    try:
        long_term_ma_value = long_term_ma.value_on(on_date)
        prev_long_term_ma_value = long_term_ma.value_on(prev_date)
        short_term_ma_value = short_term_ma.value_on(on_date)
        prev_short_term_ma_value = short_term_ma.value_on(prev_date)
    except NotEnoughDataException:
        return StockSignal.neutral

    if self._is_crossover_below_to_above(prev_short_term_ma_value,
                                         prev_long_term_ma_value,
                                         short_term_ma_value,
                                         long_term_ma_value):
        return StockSignal.buy

    if self._is_crossover_below_to_above(prev_long_term_ma_value,
                                         prev_short_term_ma_value,
                                         long_term_ma_value,
                                         short_term_ma_value):
        return StockSignal.sell

    return StockSignal.neutral

Run the tests, and all 21 tests should pass.
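The existing tests exercise MovingAverage only indirectly, through Stock. In keeping with the test-driven approach, the new class could also get its own unit tests. The sketch below mocks the series with unittest.mock so that only MovingAverage's own logic is exercised; the test class and method names are hypothetical, and MovingAverage is repeated so that the block runs on its own.

```python
import unittest
from datetime import datetime
from unittest import mock


class NotEnoughDataException(Exception):
    pass


class MovingAverage:
    # Same implementation as in the text above.
    def __init__(self, series, timespan):
        self.series = series
        self.timespan = timespan

    def value_on(self, end_date):
        moving_average_range = self.series.get_closing_price_list(
            end_date, self.timespan)
        if len(moving_average_range) < self.timespan:
            raise NotEnoughDataException("Not enough data")
        price_list = [item.value for item in moving_average_range]
        return sum(price_list) / len(price_list)


class MovingAverageTest(unittest.TestCase):
    def setUp(self):
        # A mocked series isolates MovingAverage from the real TimeSeries.
        self.series = mock.Mock()

    def test_average_of_closing_prices(self):
        self.series.get_closing_price_list.return_value = [
            mock.Mock(value=10), mock.Mock(value=11), mock.Mock(value=12)]
        ma = MovingAverage(self.series, 3)
        self.assertEqual(11, ma.value_on(datetime(2014, 2, 10)))

    def test_requests_timespan_days_ending_on_date(self):
        self.series.get_closing_price_list.return_value = [
            mock.Mock(value=5), mock.Mock(value=5)]
        MovingAverage(self.series, 2).value_on(datetime(2014, 2, 10))
        self.series.get_closing_price_list.assert_called_once_with(
            datetime(2014, 2, 10), 2)

    def test_raises_when_not_enough_data(self):
        self.series.get_closing_price_list.return_value = [mock.Mock(value=10)]
        ma = MovingAverage(self.series, 3)
        with self.assertRaises(NotEnoughDataException):
            ma.value_on(datetime(2014, 2, 10))
```

Saved as a module, these tests can be run alongside the existing ones with python -m unittest.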

Once we have extracted the calculation into a class, we find that the temporary variables we created in the Replace Calculation with Temporary Variable section of Chapter 3, Code Smells and Refactoring, are no longer needed. The code is equally self-explanatory without them, so we can now remove them, as follows:

def get_crossover_signal(self, on_date):
    prev_date = on_date - timedelta(1)
    long_term_ma = \
        MovingAverage(self.history, self.LONG_TERM_TIMESPAN)
    short_term_ma = \
        MovingAverage(self.history, self.SHORT_TERM_TIMESPAN)

    try:
        if self._is_crossover_below_to_above(
                short_term_ma.value_on(prev_date),
                long_term_ma.value_on(prev_date),
                short_term_ma.value_on(on_date),
                long_term_ma.value_on(on_date)):
            return StockSignal.buy

        if self._is_crossover_below_to_above(
                long_term_ma.value_on(prev_date),
                short_term_ma.value_on(prev_date),
                long_term_ma.value_on(on_date),
                short_term_ma.value_on(on_date)):
            return StockSignal.sell
    except NotEnoughDataException:
        return StockSignal.neutral

    return StockSignal.neutral

A final cleanup: now that we have a MovingAverage class, we can change the _is_crossover_below_to_above method to take the date and two moving average objects instead of four individual values. The method now becomes as follows:

def _is_crossover_below_to_above(self, on_date, ma, reference_ma):
    prev_date = on_date - timedelta(1)
    return (ma.value_on(prev_date)
                < reference_ma.value_on(prev_date)
            and ma.value_on(on_date)
                > reference_ma.value_on(on_date))
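Since _is_crossover_below_to_above now only calls value_on on the objects it receives, any object with a value_on method can be passed in. The sketch below exercises the same condition as a free function with a hypothetical FixedMovingAverage stub that returns precomputed values per date, and shows that swapping the two arguments turns the buy check into the sell check.

```python
from datetime import date, timedelta


class FixedMovingAverage:
    """Hypothetical stub returning precomputed moving average values."""

    def __init__(self, values_by_date):
        self.values_by_date = values_by_date

    def value_on(self, on_date):
        return self.values_by_date[on_date]


def is_crossover_below_to_above(on_date, ma, reference_ma):
    # Same condition as the method above, written as a free function:
    # ma was below reference_ma yesterday and is above it today.
    prev_date = on_date - timedelta(1)
    return (ma.value_on(prev_date) < reference_ma.value_on(prev_date)
            and ma.value_on(on_date) > reference_ma.value_on(on_date))


today = date(2014, 2, 11)
yesterday = today - timedelta(1)
short_ma = FixedMovingAverage({yesterday: 9.0, today: 11.0})
long_ma = FixedMovingAverage({yesterday: 10.0, today: 10.0})

print(is_crossover_below_to_above(today, short_ma, long_ma))  # True: buy
print(is_crossover_below_to_above(today, long_ma, short_ma))  # False: no sell
```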

And we can change the get_crossover_signal method to call this with the new parameters as follows:

def get_crossover_signal(self, on_date):
    long_term_ma = \
        MovingAverage(self.history, self.LONG_TERM_TIMESPAN)
    short_term_ma = \
        MovingAverage(self.history, self.SHORT_TERM_TIMESPAN)

    try:
        if self._is_crossover_below_to_above(
                on_date,
                short_term_ma,
                long_term_ma):
            return StockSignal.buy

        if self._is_crossover_below_to_above(
                on_date,
                long_term_ma,
                short_term_ma):
            return StockSignal.sell
    except NotEnoughDataException:
        return StockSignal.neutral

    return StockSignal.neutral

With this, our Extract Class refactoring is complete.
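Putting the pieces together, the following sketch assembles a runnable version of the refactored code. It assumes a StockSignal enum with buy, neutral, and sell members, timespans of 10 and 5 days, and a hypothetical SimpleTimeSeries stand-in that expects one closing price per day (a simplification of the book's TimeSeries). The make_stock helper and the sample prices are invented for illustration.

```python
from collections import namedtuple
from datetime import datetime, timedelta
from enum import Enum


class StockSignal(Enum):
    sell = -1
    neutral = 0
    buy = 1


class NotEnoughDataException(Exception):
    pass


PriceUpdate = namedtuple("PriceUpdate", ["timestamp", "value"])


class SimpleTimeSeries:
    """Hypothetical stand-in for TimeSeries, assuming one price per day."""

    def __init__(self):
        self.updates = []

    def update(self, timestamp, value):
        self.updates.append(PriceUpdate(timestamp, value))
        self.updates.sort(key=lambda u: u.timestamp)

    def get_closing_price_list(self, on_date, num_days):
        first_day = (on_date - timedelta(num_days - 1)).date()
        return [u for u in self.updates
                if first_day <= u.timestamp.date() <= on_date.date()]


class MovingAverage:
    def __init__(self, series, timespan):
        self.series = series
        self.timespan = timespan

    def value_on(self, end_date):
        moving_average_range = self.series.get_closing_price_list(
            end_date, self.timespan)
        if len(moving_average_range) < self.timespan:
            raise NotEnoughDataException("Not enough data")
        price_list = [item.value for item in moving_average_range]
        return sum(price_list) / len(price_list)


class Stock:
    LONG_TERM_TIMESPAN = 10
    SHORT_TERM_TIMESPAN = 5

    def __init__(self, symbol):
        self.symbol = symbol
        self.history = SimpleTimeSeries()

    def update(self, timestamp, price):
        self.history.update(timestamp, price)

    def _is_crossover_below_to_above(self, on_date, ma, reference_ma):
        prev_date = on_date - timedelta(1)
        return (ma.value_on(prev_date) < reference_ma.value_on(prev_date)
                and ma.value_on(on_date) > reference_ma.value_on(on_date))

    def get_crossover_signal(self, on_date):
        long_term_ma = MovingAverage(self.history, self.LONG_TERM_TIMESPAN)
        short_term_ma = MovingAverage(self.history, self.SHORT_TERM_TIMESPAN)
        try:
            if self._is_crossover_below_to_above(
                    on_date, short_term_ma, long_term_ma):
                return StockSignal.buy
            if self._is_crossover_below_to_above(
                    on_date, long_term_ma, short_term_ma):
                return StockSignal.sell
        except NotEnoughDataException:
            return StockSignal.neutral
        return StockSignal.neutral


def make_stock(prices):
    # Hypothetical helper: one closing price per day, starting 1 Feb 2014.
    stock = Stock("GOOG")
    for day, price in enumerate(prices):
        stock.update(datetime(2014, 2, 1) + timedelta(day), price)
    return stock


rising = make_stock([12] * 5 + [10] * 5 + [30])
print(rising.get_crossover_signal(datetime(2014, 2, 11)))  # StockSignal.buy
print(rising.get_crossover_signal(datetime(2014, 2, 10)))  # StockSignal.neutral
```

On 11 February the 5-day average (14.0) has just risen above the 10-day average (12.8), so the signal is buy. On 10 February the 10-day average for the previous day cannot be computed, so NotEnoughDataException is caught and the signal is neutral.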

The get_crossover_signal method is now extremely easy to read and understand.

Notice how the design of the MovingAverage class builds on top of the TimeSeries class that we extracted earlier. As we refactor code and extract classes, we often find that many of these classes get reused in other contexts. This is the advantage of having small classes with a single responsibility.

The refactoring into a separate class also allowed us to remove the temporary variables that we had created earlier, and made the parameters for the crossover condition much simpler. Again, these are side effects of having small classes with single responsibilities.