Category: data analytics

Market Basket Analysis with Python and Pandas


If you’ve ever worked with retail data, you’ve most likely run across the need to perform market basket analysis (also called cross-sell recommendations). If you aren’t sure what market basket analysis is, here’s a quick overview.

What is Market Basket Analysis?

In the simplest of terms, market basket analysis looks at retail sales data and determines which products are purchased together. For example, if you sell widgets and want to recommend products that are often bought alongside them, this type of analysis tells you which products to recommend when a customer views a given widget.

You can think of this type of analysis as generating the following ‘rules’:

  • If widget A, then recommend widget B, C and F
  • If widget L, then recommend widget X, Y and R

With these rules, you can then build recommendation engines for your website, store and salespeople to use when selling products to customers. Market basket analysis requires a large amount of transactional data to work well; if you have that data, running the analysis is straightforward. If you want to learn more about market basket analysis, here’s some additional reading.

In the remainder of this article, I show you how to do this type of analysis using Python and pandas.

Market Basket Analysis with Python and Pandas

There are a few approaches you can take for this type of analysis. You can use a pre-built library like MLxtend or you can build your own algorithm. I prefer the MLxtend library myself, but recently there have been memory issues when using pandas with large datasets in MLxtend, so there have been times when I’ve needed to roll my own.

Below, I provide an example of using MLxtend as well as an example of how to roll your own analysis.

Market Basket Analysis with MLxtend

For this example, we’ll use the data set found here. This data set is large enough to be useful for understanding market basket analysis, but not so large that we can’t use MLxtend (with much larger data sets, the unstack step that prepares the data for MLxtend can fail).

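To get started, you’ll need to have pandas and MLxtend installed (for example, with pip install pandas mlxtend).

Then, import pandas along with the apriori and association_rules functions from mlxtend.frequent_patterns.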

Now, let’s read in the data and drop any rows that don’t have an invoice number. Lastly, we’ll convert the InvoiceNo column to a string. NOTE: I downloaded the data file from here and stored it in a subdirectory named data.

In this data, some invoices are ‘credits’ instead of ‘debits’, so we want to remove those. They are identified with a “C” in the InvoiceNo field, so filtering for InvoiceNo values that contain “C” shows examples of these invoices.

To remove these credit invoices, we find all invoices with ‘C’ in them and keep the inverse of that selection (every invoice that does not match). This takes just one line of code.
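The cleaning steps above can be sketched on a tiny made-up frame. This is a hedged illustration, not the article’s code: the rows are invented, and the Description and Quantity column names are assumptions about the retail file (only InvoiceNo comes from the text).

```python
import pandas as pd

# Hypothetical stand-in for the retail file; the article reads the real data
# from a file saved in a "data" subdirectory instead.
df = pd.DataFrame({
    "InvoiceNo": [536365, 536365, "C536379", None, 536366],
    "Description": ["MUG", "PLATE", "MUG", "BOWL", "MUG"],
    "Quantity": [6, 2, -1, 4, 3],
})

# Drop rows without an invoice number and store invoice numbers as text
df = df.dropna(subset=["InvoiceNo"])
df = df.assign(InvoiceNo=df["InvoiceNo"].astype(str))

# Credit invoices contain a "C": inspect them, then keep everything else
is_credit = df["InvoiceNo"].str.contains("C", regex=False)
credits = df[is_credit]
df = df[~is_credit]
```

The `~` operator negates the boolean mask, which is the "inverse" selection described above.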

Now, we are ready to start our market basket analysis. First, we’ll groupby the columns that we want to consider. For the purposes of this analysis, we’ll only look at the United Kingdom orders.

Next, we want to one-hot encode the data so that there is one transaction per row, which prepares it for our MLxtend analysis.
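A minimal sketch of the group-and-unstack step, assuming columns named InvoiceNo, Description, Quantity and Country (the last three are assumptions about the data set) and using a few invented rows:

```python
import pandas as pd

# Tiny made-up order lines: invoice, product, quantity, country
df = pd.DataFrame({
    "InvoiceNo": ["1", "1", "2", "2", "2", "3"],
    "Description": ["MUG", "PLATE", "MUG", "BOWL", "MUG", "PLATE"],
    "Quantity": [6, 2, 1, 4, 2, 3],
    "Country": ["United Kingdom"] * 5 + ["France"],
})

# Keep UK orders, total each product's quantity per invoice, then pivot so
# each row is one invoice and each column is one product (missing = 0)
basket = (
    df[df["Country"] == "United Kingdom"]
    .groupby(["InvoiceNo", "Description"])["Quantity"]
    .sum()
    .unstack()
    .fillna(0)
)
```

The `.unstack()` call is the step that creates one column per product; with thousands of products and many invoices, this wide matrix is what can exhaust memory.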

Let’s take a look at the output:

market basket analysis example

Looks like a bunch of zeros. What good is that? Well… it’s exactly what we want to see. Each row is an invoice and each column is a product; a zero means that product wasn’t on that invoice. Because most invoices contain only a handful of the available products, most cells are zero. Before we continue, we want to convert all of our numbers to either a 1 or a 0 (quantities of zero or less become 0, positive quantities become 1). We can do this encoding step with the following function:

And now, we do our final encoding step:
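Here is a sketch of the encoding function and how it might be applied, on an invented quantity matrix. The function name encode_units is hypothetical. Applying it column by column with Series.map avoids DataFrame.applymap, which newer pandas versions deprecate in favor of DataFrame.map.

```python
import pandas as pd

def encode_units(quantity):
    """Return 1 if the product was bought (quantity > 0), otherwise 0."""
    return 1 if quantity > 0 else 0

# Hypothetical invoice-by-product quantity matrix from the unstack step
basket = pd.DataFrame(
    {"BOWL": [0.0, 4.0, -1.0], "MUG": [6.0, 3.0, 0.0], "PLATE": [2.0, 0.0, 1.0]},
    index=pd.Index(["1", "2", "3"], name="InvoiceNo"),
)

# Element-wise application of the function, one column at a time
encoded = basket.apply(lambda col: col.map(encode_units))

# Vectorized alternative producing the same 0/1 values, faster on large data
encoded_fast = (basket > 0).astype(int)
```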

Now, let’s find out which items are frequently purchased together. We do this by applying the MLxtend apriori function to our dataset.

There’s one thing we need to think about first. The apriori function requires us to provide a minimum level of ‘support’. Support is the fraction of transactions (here, invoices) that contain a given itemset. If you set support to 50%, you’ll only get itemsets that appear in at least 50% of transactions. I like to set support to around 5% when starting out so I can see some results, and then adjust from there. Setting the support level too high could lead to very few (or no) results, and setting it too low could require an enormous amount of memory to process the data.

In the case of this data, I originally set the min_support to 0.05 but didn’t receive any results, so I changed it to 0.03.
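To make the definition of support concrete, here is a hand calculation on an invented one-hot basket. With MLxtend, the equivalent call for all itemsets would be `apriori(basket, min_support=0.03, use_colnames=True)`; this sketch only computes single-item and pair support with plain pandas.

```python
import pandas as pd

# Hypothetical one-hot basket: rows are invoices, columns are products (1 = bought)
basket = pd.DataFrame(
    {"MUG": [1, 1, 1, 0], "PLATE": [1, 0, 1, 0], "BOWL": [0, 1, 0, 0]},
    index=["inv1", "inv2", "inv3", "inv4"],
)

# Support of one item: share of invoices that contain it
item_support = basket.mean()

# Support of a pair: share of invoices that contain both items
pair_support = (basket["MUG"] & basket["PLATE"]).mean()

# Keep only items at or above the chosen threshold
min_support = 0.5
frequent_items = item_support[item_support >= min_support]
```

Here MUG has support 0.75, PLATE 0.5 and BOWL 0.25, so BOWL is dropped at a 50% threshold.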

The final step is to build your association rules using the MLxtend association_rules function. You can choose the metric you are most interested in (for example, lift or confidence) and set a minimum value for that metric (called min_threshold). With the metric set to confidence, min_threshold is the minimum confidence a rule needs in order to be returned; for example, a min_threshold of 1 returns only rules with 100% confidence. I usually start with 0.7.
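Confidence and lift follow directly from support: confidence(A → C) = support(A and C) / support(A), and lift(A → C) = confidence(A → C) / support(C). A small sketch on invented data; rule_metrics is a hypothetical helper, not part of MLxtend.

```python
import pandas as pd

# Hypothetical one-hot basket (rows = invoices, 1 = product bought)
basket = pd.DataFrame(
    {"MUG": [1, 1, 1, 0], "PLATE": [1, 0, 1, 0], "BOWL": [0, 1, 0, 1]},
)

def rule_metrics(basket, antecedent, consequent):
    """Support, confidence and lift for the rule antecedent -> consequent."""
    support_a = basket[antecedent].mean()
    support_c = basket[consequent].mean()
    support_ac = (basket[antecedent] & basket[consequent]).mean()
    confidence = support_ac / support_a
    lift = confidence / support_c
    return {"support": support_ac, "confidence": confidence, "lift": lift}

plate_to_mug = rule_metrics(basket, "PLATE", "MUG")
mug_to_plate = rule_metrics(basket, "MUG", "PLATE")
```

PLATE → MUG has confidence 1.0 while MUG → PLATE has about 0.67, but both have the same lift (about 1.33), since lift is symmetric. With metric set to confidence and min_threshold=0.7, the first rule would be kept and the second dropped.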

Running association_rules with these settings generates 16 rules for our market basket analysis.

MLxtend rules for market basket analysis

This gives us a good number of data points to look at for this analysis. Now, what does this tell us?

If you look at the antecedents and consequents columns, you’ll see product names. Each rule says that customers who buy the antecedent items tend to buy the consequent items as well. You can use this information to build a cross-sell recommendation system that promotes these products together on your website (or during in-person sales).
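One way to turn rules into the "if widget A, then recommend widget B" lookup described at the start is sketched below. The rules are invented, stored as frozensets in the antecedents and consequents columns the way MLxtend’s output is shaped; build_recommendations is a hypothetical helper that only handles single-item antecedents.

```python
import pandas as pd

# Made-up rules with frozenset antecedents/consequents
rules = pd.DataFrame({
    "antecedents": [frozenset({"MUG"}), frozenset({"MUG"}), frozenset({"PLATE"})],
    "consequents": [frozenset({"PLATE"}), frozenset({"BOWL"}), frozenset({"MUG"})],
    "confidence": [0.8, 0.75, 0.9],
})

def build_recommendations(rules):
    """Map each single-item antecedent to its consequents, best confidence first."""
    recs = {}
    for row in rules.sort_values("confidence", ascending=False).itertuples(index=False):
        if len(row.antecedents) != 1:
            continue  # skip multi-item triggers to keep the lookup simple
        (item,) = row.antecedents
        targets = recs.setdefault(item, [])
        for other in sorted(row.consequents):
            if other not in targets:
                targets.append(other)
    return recs

recommendations = build_recommendations(rules)
```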

Without knowing much more about the business that generated this data, we can’t really do much more with it. If you were using your own data, you’d be able to dig a bit deeper to find those rules with higher confidence and/or lift to help you understand the items that are sold together most often and start building strategies to promote those items (or other items if you are trying to grow sales in other areas of your business).

When can you not use MLxtend?

MLxtend can be used anytime you want, and it is my preferred approach for market basket analysis. That said, there’s an issue (as of the date of this article) with using pandas on large datasets at the step that unstacks the grouped data into one column per product.

You can see the issue here.

When you run across this issue, you’ll need another approach to running a market basket analysis. You can probably find ways to work around the pandas unstack problem, but what I’ve done recently is roll my own analysis (it’s actually pretty simple to do). That’s what I’ll show you below.

To get started, we need to import a few more libraries:

Let’s take our original dataframe and assign it to a new dataframe so we know we are working with a separate data set from the one above. We’ll use the same United Kingdom filter that we did before.

Now, let’s grab just the order data. For this, we’ll get the InvoiceNo and StockCode columns, since all we care about is whether an item exists on an invoice. Remember, we’ve already removed the ‘credit’ invoices in the steps above, so all we have are regular invoices. NOTE: There *will* be differences in the output of this approach vs MLxtend’s approach, just as there will be differences with other approaches you might use for market basket analysis.
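A sketch of this step on invented rows: the result is a Series indexed by invoice whose values are stock codes. Dropping duplicate (invoice, item) rows is an added assumption here, so that an item listed on several lines of one invoice is only counted once.

```python
import pandas as pd

# Made-up cleaned invoice lines (credit invoices already removed)
df = pd.DataFrame({
    "InvoiceNo": ["1", "1", "2", "2", "3", "3", "3"],
    "StockCode": ["A", "B", "A", "C", "A", "B", "B"],
})

# One entry per (invoice, item): index = invoice, value = stock code
orders = (
    df[["InvoiceNo", "StockCode"]]
    .drop_duplicates()
    .set_index("InvoiceNo")["StockCode"]
)
```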

Now that we have a pandas Series of items, let’s calculate the item frequency and support values.
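A minimal sketch of the frequency and support calculation, assuming the order Series has one entry per (invoice, item). The helper name item_stats and the column names freq and support are illustrative choices.

```python
import pandas as pd

# Hypothetical order data: index = invoice, value = stock code
orders = pd.Series(
    ["A", "B", "A", "C", "A", "B"],
    index=pd.Index(["1", "1", "2", "2", "3", "3"], name="InvoiceNo"),
    name="StockCode",
)

def item_stats(orders):
    """Number of orders containing each item, and that count as a share of orders."""
    n_orders = orders.index.nunique()
    stats = orders.value_counts().rename("freq").to_frame()
    stats["support"] = stats["freq"] / n_orders
    return stats

stats = item_stats(orders)
```

Item A appears on all three orders (support 1.0), B on two and C on one.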

Let’s filter out any items that don’t have support above our min_support level.

Next, we need to filter out orders that have only one item on the invoice, since those orders can’t tell us anything about which items are bought together.

Now, let’s calculate our stats dataframe again with this new order data-set.
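The three steps above (drop low-support items, drop orders left with a single item, recompute the statistics) might look like this on invented data. The min_support value of 0.4 is arbitrary for the toy data, and item_stats is the same hypothetical helper as in the previous sketch.

```python
import pandas as pd

# Hypothetical order data: 5 invoices, one entry per (invoice, item)
orders = pd.Series(
    ["A", "B", "A", "C", "A", "A", "B", "D", "B"],
    index=pd.Index(["1", "1", "2", "2", "3", "4", "4", "4", "5"], name="InvoiceNo"),
    name="StockCode",
)
min_support = 0.4

def item_stats(orders):
    n_orders = orders.index.nunique()
    stats = orders.value_counts().rename("freq").to_frame()
    stats["support"] = stats["freq"] / n_orders
    return stats

# 1. Keep only items whose support meets the threshold
stats = item_stats(orders)
keep_items = stats[stats["support"] >= min_support].index
orders = orders[orders.isin(keep_items)]

# 2. Keep only orders that still have at least two items (a pair is needed)
items_per_order = orders.index.value_counts()
keep_orders = items_per_order[items_per_order >= 2].index
orders = orders[orders.index.isin(keep_orders)]

# 3. Recompute item statistics on the filtered orders
stats = item_stats(orders)
```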

Time to do the fun stuff: calculating the itemsets (item pairs). We’ll create a generator function that yields the item pairs on each order and then send our new order data set through it. Then, we calculate how often each pair of items appears together (named frequencyAC) as well as its support (named supportAC). Finally, we filter out the itemsets that are below our min_support level.
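A sketch of the pair generator and pair support, using itertools.combinations and collections.Counter on invented orders. The function name get_item_pairs and the column names item_A and item_C are hypothetical; frequencyAC and supportAC follow the names used in the text.

```python
from collections import Counter
from itertools import combinations

import pandas as pd

# Hypothetical filtered order data: index = invoice, value = item
orders = pd.Series(
    ["A", "B", "A", "B", "C", "A", "C"],
    index=pd.Index(["1", "1", "2", "2", "2", "3", "3"], name="InvoiceNo"),
)

def get_item_pairs(orders):
    """Yield every unordered pair of distinct items that share an invoice."""
    for _, items in orders.groupby(level=0):
        for pair in combinations(sorted(set(items)), 2):
            yield pair

min_support = 0.5
n_orders = orders.index.nunique()

# Count pair occurrences, then turn counts into support values
pair_counts = Counter(get_item_pairs(orders))
pairs = pd.DataFrame(
    [(a, c, n) for (a, c), n in pair_counts.items()],
    columns=["item_A", "item_C", "frequencyAC"],
)
pairs["supportAC"] = pairs["frequencyAC"] / n_orders
pairs = pairs[pairs["supportAC"] >= min_support]
```

Sorting each order’s items before calling combinations means (A, B) and (B, A) are always counted as the same pair. In this toy data, (B, C) appears on only one of three orders and falls below the threshold.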

Next, we can calculate our association rules. First, let’s unstack our itemsets and create the necessary data columns for support, lift, etc.

Finally, let’s look at the rules themselves. We only want to keep rules with confidence > 0.5.
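A sketch of turning pair statistics into directed rules and keeping those with confidence above 0.5. The support values are invented but consistent with four orders (A on every order, B on two, C on three); build_rules and the antecedent/consequent column names are hypothetical.

```python
import pandas as pd

# Hypothetical statistics from the previous steps (4 orders)
item_support = pd.Series({"A": 1.0, "B": 0.5, "C": 0.75})
pairs = pd.DataFrame({
    "item_A": ["A", "A"],
    "item_C": ["B", "C"],
    "supportAC": [0.5, 0.75],
})

def build_rules(pairs, item_support):
    """Turn each unordered pair into two directed rules with confidence and lift."""
    forward = pairs.rename(columns={"item_A": "antecedent", "item_C": "consequent"})
    backward = pairs.rename(columns={"item_A": "consequent", "item_C": "antecedent"})
    rules = pd.concat([forward, backward], ignore_index=True)
    rules["supportA"] = rules["antecedent"].map(item_support)
    rules["supportC"] = rules["consequent"].map(item_support)
    rules["confidence"] = rules["supportAC"] / rules["supportA"]
    rules["lift"] = rules["confidence"] / rules["supportC"]
    return rules

rules = build_rules(pairs, item_support)
rules_over_50 = rules[rules["confidence"] > 0.5].sort_values("lift", ascending=False)
```

A → B has confidence exactly 0.5 and is dropped. Every rule here has a lift of 1.0 because A is on every order, so knowing A was bought says nothing extra; on real data, lift above 1 is what signals a useful association.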

Looking at the rules_over_50 data, we see our final set of rules using our ‘roll your own’ approach.

final rules for market basket

These rules are going to be a bit different from what we get with MLxtend, but that’s OK, as it gives us another set of data to look at (and the only set of data to look at when your data is too large to use MLxtend). One extension to this approach would be to add a step that replaces the stock code numbers with the item descriptions. I’ll leave it to you to do that work.
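If you want a starting point for that extension, one possible sketch is to build a stock-code-to-description lookup from the order lines and map it onto the rule columns. The data, the Description column and the antecedent/consequent column names are assumptions; taking the first description seen for each code is a simplifying choice.

```python
import pandas as pd

# Made-up order lines with descriptions (one code appears with a stray space)
df = pd.DataFrame({
    "StockCode": ["A", "B", "B", "C"],
    "Description": ["WHITE MUG", "BLUE PLATE", "BLUE PLATE ", "RED BOWL"],
})
rules_over_50 = pd.DataFrame({
    "antecedent": ["B", "C"],
    "consequent": ["A", "A"],
    "confidence": [1.0, 1.0],
})

# Code -> description lookup (first description seen for each code wins)
names = (
    df.drop_duplicates("StockCode")
    .set_index("StockCode")["Description"]
    .str.strip()
)

labeled = rules_over_50.assign(
    antecedent_name=rules_over_50["antecedent"].map(names),
    consequent_name=rules_over_50["consequent"].map(names),
)
```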

Forecasting with Random Forests

When it comes to forecasting data (time series or other types of series), people look to techniques like basic regression, ARIMA, ARMA, GARCH, or even Prophet, but don’t discount the use of Random Forests for forecasting data.

Random Forests are generally considered a classification technique but regression is definitely something that Random Forests can handle.

For this post, I am going to use a dataset found here called Sales Prices of Houses in the City of Windsor (CSV here, description here). For the purposes of this post, I’ll only use the price and lotsize columns. Note: In a future post, I’m planning to revisit this data and perform multivariate regression with Random Forests.

To get started, let’s import all the necessary libraries. As always, you can grab a Jupyter notebook to run through this analysis yourself here.

Now, let’s load the data:

Again, we are only using two columns from the data set – price and lotsize. Let’s plot this data to take a look at it visually to see if it makes sense to use lotsize as a predictor of price.

Housing Data Visualization

Looking at the data, it looks like a decent guess to think lotsize might forecast price.

Now, let’s set up our dataset to get our training and testing data ready.

In the above, we set X and y for the random forest regressor and then set our training and test data. For training data, we are going to take the first 400 data points to train the random forest and then test it on the last 146 data points.

Now, let’s run our random forest regression model. First, we need to import RandomForestRegressor from sklearn.ensemble:

And now, let’s run our Random Forest regression and see what we get.
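The split-and-fit steps can be sketched as follows. Because the Windsor CSV isn’t bundled here, the sketch generates 546 synthetic rows with a roughly linear price/lotsize relationship (an assumption made only so the code runs), then uses the same first-400 / last-146 split. The n_estimators and random_state values are arbitrary choices, not the article’s settings.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in for the price/lotsize columns (546 rows)
rng = np.random.default_rng(0)
lotsize = rng.uniform(1_500, 16_000, size=546)
price = 20_000 + 4.0 * lotsize + rng.normal(0, 10_000, size=546)
df = pd.DataFrame({"lotsize": lotsize, "price": price})

# scikit-learn expects a 2-D feature matrix, hence the double brackets
X = df[["lotsize"]]
y = df["price"]

# First 400 rows for training, the remaining 146 for testing (no shuffling)
X_train, X_test = X.iloc[:400], X.iloc[400:]
y_train, y_test = y.iloc[:400], y.iloc[400:]

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

predictions = pd.DataFrame({"price": y_test, "predicted_price": model.predict(X_test)})
```

Slicing by position rather than calling train_test_split keeps the "first 400 / last 146" order described in the text.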

Let’s visualize the price and the predicted_price.

price vs predicted price

That’s really not a bad outcome for a wild guess that lotsize predicts price. Visually, it looks pretty good (although there are definitely errors).

Let’s look at the base level error. First, a quick plot of the ‘difference’ between the two.

Price vs Predicted Difference

There are some very large errors in there. Let’s look at metrics like R-squared and mean squared error. First, let’s import the appropriate functions from sklearn.

Now, let’s look at R-squared:

R-squared is 0.6976, or roughly 0.7. That’s not great, but it’s not terrible either for a wild guess at a predictor. A value of 0.7 (or 70%) tells you that roughly 70% of the variation in price on the test data is explained by the model’s lotsize-based predictions. That’s really not bad in the grand scheme of things.
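To show what R-squared measures, this sketch computes it by hand as 1 minus the residual sum of squares divided by the total sum of squares, and checks the result against scikit-learn’s r2_score. The prices are made up.

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

# Made-up actual and predicted prices
y_true = np.array([42_000.0, 38_500.0, 60_000.0, 49_500.0, 75_000.0])
y_pred = np.array([45_000.0, 36_000.0, 58_000.0, 52_000.0, 70_000.0])

# R-squared = 1 - (residual sum of squares / total sum of squares)
ss_res = np.sum((y_true - y_pred) ** 2)
ss_tot = np.sum((y_true - y_true.mean()) ** 2)
r2_manual = 1 - ss_res / ss_tot

r2 = r2_score(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)  # average squared error, in price units squared
```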

I could go on with other error calculations, but the point of this post isn’t to show ‘accuracy’; it’s to show the ‘process’ of using Random Forests for forecasting.

Look for more posts on using random forests for forecasting.


If you want a very good deep-dive into using Random Forest and other statistical methods for prediction, take a look at The Elements of Statistical Learning: Data Mining, Inference, and Prediction, Second Edition (Amazon Affiliate link)

Quick Tip: SQLAlchemy for MySQL and Pandas

For years I’ve used the mysql-python library for connecting to MySQL databases. It’s worked well for me, but there are times when you need more speed and/or better connection management than you get with mysql-python. That’s where SQLAlchemy comes in.

Before diving into this, if you are doing things that aren’t dependent on speed (e.g., it doesn’t matter if it takes 1 second to connect to the database and grab your data and close the database) then you can easily ignore this tip. That said, if you have multiple connections, that connect time can add up.

For example, I recently had an issue where it was taking 4.5+ seconds to connect to a database, run analysis and spit out the results. That’s not terrible if it’s something only you use, but if it’s a production system and speed is a requirement, that might be too long (and it IS too long).

When I did some analysis using Python’s timer(), I found that more than 50% of that 4.5 seconds was spent establishing database connections, so I grabbed my trusty SQLAlchemy toolkit and went to work.
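One way to get this kind of timing breakdown is sketched below with time.perf_counter and a small context manager. An in-memory SQLite database (Python’s built-in sqlite3) stands in for MySQL purely so the sketch runs anywhere; the helper name timed is hypothetical.

```python
import sqlite3
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results):
    """Store the wall-clock seconds spent in the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start

timings = {}
with timed("connect", timings):
    conn = sqlite3.connect(":memory:")  # stand-in for a MySQL connection
with timed("query", timings):
    rows = conn.execute("SELECT 1").fetchall()
with timed("close", timings):
    conn.close()
```

Comparing timings["connect"] against timings["query"] shows where the time goes. Against a remote MySQL server, the connect step is typically the expensive one.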

For those of you that don’t know, SQLAlchemy is a ‘Python SQL toolkit and Object Relational Mapper’ (ORM) that is meant to make working with SQL databases easier. For me, the ORM aspect tends to make things more difficult, so I stick with plain SQL queries, but the SQL toolkit aspect of SQLAlchemy makes a lot of sense and saves time when connecting to a SQL database.

Before we get into the SQLAlchemy aspects, let’s take a second to look at how to connect to a SQL database with the mysql-python connector (or at least take a look at how I do it).

First, let’s set up our import statements. For this, we will import MySQLdb, pandas and pandas.io.sql in order to read SQL data directly into a pandas dataframe.

Next, let’s create a database connection, create a query, execute that query and close the connection.

This is a fairly standard approach to reading data into a pandas dataframe from mysql using mysql-python.  This approach is what I had been using before when I was getting 4.5+ seconds as discussed above. Note – there were multiple database calls and some analysis included in that 4.5+ seconds. A basic database call like the above ran in approximately 0.45 seconds in my code that I was trying to improve performance on and establishing the database connection was the majority of that time.

To improve performance, especially if you will make multiple calls to multiple tables, you can use SQLAlchemy with pandas. You’ll need to pip install sqlalchemy if you don’t have it installed already. Now, let’s set up our imports:

Now you can set up the connection string for your database. SQLAlchemy takes everything in a single URL; for MySQL with the MySQLdb driver, it looks like the following:

mysql+mysqldb://USER:PW@DBHOST/DB

where USER is your username, PW is your password, DBHOST is the database host and DB is the database you want to connect to.

To set up the persistent connection, you create an engine from that URL with SQLAlchemy’s create_engine function.

Now you have a connection to your database and you’re ready to go. There’s no need to worry about cursors or opening and closing database connections; the SQLAlchemy engine keeps a pool of connections and manages them for you.

Now all you need to do is focus on your SQL queries and loading the results into a pandas dataframe.
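Here is a minimal sketch of the engine-plus-pandas pattern. It uses an in-memory SQLite URL (sqlite://) instead of the MySQL URL so it can run without a database server; with MySQL you would pass a URL like mysql+mysqldb://USER:PW@DBHOST/DB to create_engine instead. The sales table and its rows are invented for the example.

```python
import pandas as pd
from sqlalchemy import create_engine, text

# For MySQL: create_engine("mysql+mysqldb://USER:PW@DBHOST/DB")
# An in-memory SQLite database stands in here so the sketch runs anywhere.
engine = create_engine("sqlite://")

# Create and fill a small example table; engine.begin() commits on success
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE sales (item TEXT, qty INTEGER)"))
    conn.execute(
        text("INSERT INTO sales (item, qty) VALUES (:item, :qty)"),
        [{"item": "mug", "qty": 3}, {"item": "plate", "qty": 5}],
    )

# Each read borrows a pooled connection from the engine and hands it back
df = pd.read_sql_query("SELECT item, qty FROM sales ORDER BY item", engine)
```

Creating the engine once and reusing it for every query is what avoids paying the connection cost on each call.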

That’s all it takes. AND… it’s faster. In the example above, my database setup / connection / query / closing times dropped from 0.45 seconds to 0.15 seconds. Times will vary based on what data you are querying and where the database is, of course, but in this case everything was the same except that mysql-python was replaced with SQLAlchemy and the new(ish) read_sql_query function in pandas.

Using this approach, the 4.5+ seconds it took to grab data, analyze the data and return the data was reduced to about 1.5 seconds. Impressive gains for just switching out the connection/management method.