Technical | 01/10/2023

Deep Learning for Tabular Data: An Overview

Michael Clark
Machine Learning Scientist

Deep Learning (DL) has made remarkable strides this past year, particularly in computer vision and natural language processing. We saw tools like Stable Diffusion create beautiful surrealistic images from a prompt, others solve complex protein-folding problems, and, more recently, ChatGPT write a short story or create a working app from a text prompt. Exciting times indeed!

Then again, perhaps we shouldn't be surprised, as these are the domains where DL has shone brightly in the past as well. Interestingly, when we switch to tabular data, the kind of data more common in everyday data science, deep learning has not been as successful.

We can think of data in terms of general types. A common distinction is structured vs. unstructured, with text as the typical example of unstructured data and tabular data as the typical example of structured data, or, even more generally, tabular (structured) vs. anything else (unstructured). This distinction is not applied consistently, and unstructured data can be, and often is, given a ‘structure’ so that it can be used for modeling. Perhaps a better distinction for our purposes is homogeneous vs. heterogeneous: either the data inputs are all of the same type (homogeneous), or they are of varying types (heterogeneous). Homogeneous data would include images, text, and even some structured data settings (e.g., multivariate time series, like financial data). Heterogeneous data are more complex, and are the kind we typically observe in tabular settings. Even in homogeneous tabular data settings, it's usually trivial to add useful features from other sources (e.g., adding seasonal or firm characteristics to the financial data), which makes the data heterogeneous.

Tabular data is characterized by feature heterogeneity: mixed input types coming from possibly disparate sources, often varying wildly in reliability and integrity. For example, we might have observational data on individuals and their interactions with a website or app. For these individuals we could have demographic information (e.g., age, what state they live in), information on how often they visit or interact with the website (e.g., how many times they clicked on an ad in a month), date-related information about when they visited, and more. Some features are categorical, others are numeric but highly skewed, some come in clusters (e.g., multiple survey items about a specific topic), and some may be unreliable (e.g., a survey question asking how favorably they view a particular topic).
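To make this heterogeneity concrete, the sketch below builds a small simulated table like the one just described and turns it into a numeric matrix with scikit-learn. The column names (age, state, ad_clicks, visit_date) and the transformations chosen (standardizing, log-transforming the skewed count, one-hot encoding categories, extracting the day of week from the date) are illustrative assumptions, not a recommended recipe.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, StandardScaler

rng = np.random.default_rng(0)
n = 200

# Hypothetical user-level table mixing several feature types (simulated values)
df = pd.DataFrame({
    "age": rng.integers(18, 80, size=n),                            # numeric
    "state": rng.choice(["CA", "NY", "TX", "MI"], size=n),          # categorical
    "ad_clicks": rng.poisson(lam=rng.gamma(0.5, 4.0, size=n)),      # skewed counts
    "visit_date": pd.to_datetime("2022-01-01")
    + pd.to_timedelta(rng.integers(0, 365, size=n), unit="D"),      # dates
})

# Derive a simple date feature: day of week of the visit (0 = Monday)
df["visit_dow"] = df["visit_date"].dt.dayofweek

# Each feature type gets its own transformation; unlisted columns are dropped
preprocess = ColumnTransformer(
    [
        ("num", StandardScaler(), ["age"]),
        ("skewed", FunctionTransformer(np.log1p), ["ad_clicks"]),
        ("cat", OneHotEncoder(handle_unknown="ignore"), ["state", "visit_dow"]),
    ],
    sparse_threshold=0.0,  # always return a dense array
)

X = preprocess.fit_transform(df)
print(df.dtypes)
print(X.shape)
```

Every feature type needs its own handling before any model sees it, which is part of what separates this kind of data from a homogeneous matrix of pixel intensities.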

Among the challenges with tabular data are mixed data types, feature inter-correlations and redundancy, missing values, features on very different scales, and noisy measurement. These challenges often produce data with a low signal-to-noise ratio, which makes it difficult for any model to perform well.
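A small simulation can show why a low signal-to-noise ratio caps performance no matter which model we use. It assumes a hypothetical outcome y = f(X) + noise, where the noise variance is large relative to the variance of f. Even the true function f cannot explain more than Var(f) / (Var(f) + Var(noise)) of the variance in y. This is a sketch under simulated assumptions, not a result from real data.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000

X = rng.normal(size=(n, 3))
# Assumed true signal: a nonlinear function of the features
signal = np.sin(X[:, 0]) + 0.5 * X[:, 1] * X[:, 2]
# Assumed noisy measurement with variance 4, much larger than Var(signal)
noise = rng.normal(scale=2.0, size=n)
y = signal + noise


def r2(y_true, y_pred):
    """Coefficient of determination (R^2)."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot


# Theoretical ceiling on R^2 given this noise level
ceiling = signal.var() / (signal.var() + noise.var())
# R^2 of the true function itself, i.e. a perfect model
oracle = r2(y, signal)
print(f"R^2 ceiling ~ {ceiling:.3f}, oracle R^2 ~ {oracle:.3f}")
```

With these assumed settings the ceiling comes out around 0.15. No model, deep or otherwise, can explain more than that, because the rest of the variance is noise.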

On the surface, one might expect that DL, built on the universal function approximators better known as neural networks, should make quick work of tabular data, given its immense successes in other domains. After all, these are the kind of data analyzed by millions each day in spreadsheets around the world. Yet this type of data has proven consistently challenging for DL, due to the difficulties inherent in such data.

Techniques for analyzing tabular data

Aside from deep learning, what are some typical ways we might analyze this type of data? We can start with linear/logistic regression or anything along those lines (generalized linear/additive models). These approaches are the easiest to conduct, allow for relatively easy interpretation, have common methods for understanding the uncertainty in the results, and are easily applied to the small data situations that are common with tabular data. However, these techniques are usually not very performant compared to other tools we might use. Why? Because in most data situations there are interactions and nonlinear relationships, which these models can only capture with additional steps, such as explicitly adding interaction or polynomial/spline terms. We may also need penalization (regularization) and/or other techniques to make better predictions on data beyond what we have, also known as “generalization”. Some of these issues are easily and routinely addressed, like adding some regularization, but others are more difficult; for example, handling interactions and nonlinearities becomes tedious and problematic as the number of features grows. In addition, standard statistical methods become more difficult to apply in very large data situations without compromising some of their inherent benefits (e.g., easy uncertainty estimation).
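The extra steps described above can be sketched with scikit-learn on simulated data. A plain linear regression misses an assumed interaction and curvature. Explicitly adding all squared and pairwise interaction terms, with ridge regularization to help generalization, recovers most of the signal. The data-generating function is hypothetical.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

rng = np.random.default_rng(1)
n = 2_000
X = rng.uniform(-2, 2, size=(n, 4))
# Assumed truth: a main effect, an interaction (x0 * x1), and a quadratic term (x2^2)
y = X[:, 0] + 2 * X[:, 0] * X[:, 1] + X[:, 2] ** 2 + rng.normal(scale=0.5, size=n)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

# Plain linear model: main effects only
linear = LinearRegression().fit(X_train, y_train)

# Add all pairwise interactions and squared terms, then penalize with ridge
poly_ridge = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    StandardScaler(),
    Ridge(alpha=1.0),
).fit(X_train, y_train)

print("linear R^2:      ", round(linear.score(X_test, y_test), 3))
print("poly + ridge R^2:", round(poly_ridge.score(X_test, y_test), 3))
# Number of expanded terms generated from the 4 original features
print(poly_ridge[0].n_output_features_)
```

With 4 features the degree-2 expansion has 14 terms; with 50 features it would have 1,325, before even considering splines or higher-order interactions. That growth is the tedium referred to above.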

More complex models for tabular data

If we're primarily interested in predictive performance, a standard approach to tabular data is boosting, with popular libraries such as lightgbm, xgboost, and catboost, all of which are available in Python and R. Boosting is fast and will generally do much better than linear model approaches out of the box. When appropriately tuned, these tools are hard to beat, and the differences between the specific boosting implementations mentioned above are minor in practice. They are also applicable to large data situations, possibly with dozens of features or more, where standard statistical methods will struggle or fail. In addition, we can get measures of feature importance, and use other approaches to aid interpretability.

As we increase the complexity of the model, other aspects of the endeavor start to suffer. It becomes more difficult to get straightforward interpretations of feature-target relationships due to complex interactions and feature inter-correlations. It also becomes more difficult to understand the uncertainty in the results without adding computational complexity, though methods exist for that as well. In short, just because we get a well-performing model doesn't mean we get everything we'd want.

Deep learning for analyzing tabular data

Various deep learning methods have been applied to tabular data, from standard multi-layer perceptrons to complex architectures and ensembles of those. Unfortunately, none consistently beat boosting methods, and on closer inspection, DL performance is often worse, or boosting performance better, than what is touted in the papers. All too often the hype is just that. Many of these papers rely on differences several decimal places out to report a win, below the actual precision of the metric (would you actually care about a .xxx3 vs. a .xxx1 result in accuracy?), and some even benchmark against boosting tools from years ago. I have even seen a case where someone produced better boosting results than a paper reported, without any hyperparameter tuning, and those results beat the paper's reported DL model performance!

We don't want to fault the technique for how it's used, and beyond the hype, DL does seem to excel with strictly numeric data that have notable nonlinear relationships, including interactions (see, for example, my meta-analysis of several papers' results). Consider the case of an experimental data situation with a simple treatment condition (control vs. treatment) and another two-group feature (e.g., a pre-post setting). Assuming good experimental design and control, the simplest model we could run is just y ~ trt + x + trt*x. A DL model has little to add there, and that simple model is valid and useful in many experimental settings. However, in even slightly more complex settings, for example, 10 features with notable interactions and nonlinear relationships with the outcome, standard statistical models will have difficulty capturing the complexity without a lot of work. We could apply DL techniques in that case, though even then, DL may only slightly improve upon boosting, which can also model those complexities.
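The simple experimental model above can be written out directly. The sketch below simulates a hypothetical 2x2 design (treatment vs. control crossed with a pre/post indicator) with assumed true coefficients, and recovers them by ordinary least squares with NumPy. In R the same model would be lm(y ~ trt*x), which expands to trt + x + trt:x.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 400

trt = rng.integers(0, 2, size=n)  # 0 = control, 1 = treatment
x = rng.integers(0, 2, size=n)    # 0 = pre, 1 = post
# Assumed true effects: intercept 1.0, trt 0.5, x 0.2, interaction 1.5
y = 1.0 + 0.5 * trt + 0.2 * x + 1.5 * trt * x + rng.normal(scale=1.0, size=n)

# Design matrix: intercept, both main effects, and their product (interaction)
design = np.column_stack([np.ones(n), trt, x, trt * x])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)

for name, b in zip(["intercept", "trt", "x", "trt:x"], coef):
    print(f"{name:>9}: {b:.2f}")
```

The model has four parameters, each directly interpretable as a cell mean or a difference of cell means, which is hard to improve on with a more flexible learner.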

Summary & Recommendations

At present, deep learning for tabular data may have a place, but it has had no major success with such data, and it's not clear what tricks remain to be pulled off. But time will tell, and given recent DL advances in other domains, it seems anything is possible. As such, for now our recommendation is to pursue deep learning for tabular data when you have a large matrix of homogeneous data, want to take advantage of the tools' ability to handle large data situations, or want to use it as part of an ensemble with other methods. You may also want to use a DL approach because part of your problem involves NLP or computer vision, and you want to use consistent tools, or even combine the models, to predict some target.

The following can serve as some general guidelines for working with tabular data.

Consider using deep learning if:

• Your tabular data are homogeneous
• You want to use a standard model (like linear regression) with very large data
• You'd like to combine your tabular data with other DL models/tools (e.g. NLP)

If the goal is performance with tabular data:

• Boosting should still be the go-to, and should remain a baseline benchmark for any deep learning approach
• You should consider ensembling approaches, including deep learning

If the goal is interpretability:

• Consider using GLMs/GAMMs
• Realize that if the data relationships are complex, easy interpretability may not be achievable