You’ve probably heard about deep learning by now. You’ve heard about what it can do, and you’ve seen it in use when you’ve unlocked your phone with your face, had Facebook suggest which of your friends might be in a picture, or had YouTube recommend a video to watch. These problems are all similar in the sense that they are solved by “deep learning”—that is, by using artificial neural networks that learn from data to help a computer solve a problem automatically. You may be surprised to hear that deep learning has its roots in concepts dating as far back as the 1950s, with the fundamental ideas of the field solidified by the ’90s.
So why is there so much buzz about neural networks now?
The answer is twofold: we finally have fast enough computers to run large neural networks, and we finally have enough data to teach them something meaningful.
Until recently it was simply impractical, from a compute perspective, to train and deploy a neural network to do anything useful, so they mostly lived in the realm of theory at various CS and statistics departments across the country. On top of this, deep neural networks need a LOT of data to learn anything meaningful: the most well-known image classification dataset, ImageNet, has over a million images. As more of our waking seconds are spent online, we have more data about every single thing we do in this life: we post pictures, we record videos, we listen to songs, we read articles, all online. All of these activities generate data that is stored in vast repositories on the internet.
This reasoning is all well and good as an oversimplified answer. However, I’ve sneakily glossed over a deep question that seems obvious at first glance: what exactly do I mean by “deep neural networks need a lot of data”?
I can hear you thinking that this is sort of a silly question. That’s OK, but the goal of the rest of this post is to make you realize that it’s not. First of all, “data” is a hugely ambiguous word that means many different things. But a machine learning algorithm is not ambiguous—it is precise, and it is needy. When training a neural network to solve most common computer vision tasks, you can’t just feed it huge chunks of random data from the internet. You have to go through the process of labeling data, which means getting a human to describe the contents of an image in a way a computer can understand.
For example, let’s say that we want to train a computer to identify whether a particular image is of a cat or a dog. We can’t just download a bunch of pictures off of Instagram and hope to get anything meaningful out of them. We have to curate a large collection of images of both cats and dogs, and get a person to label whether each image has a cat or a dog in it.
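To make “labeled data” concrete, here’s a minimal sketch of what such a dataset might look like on disk. The folder layout and file names are hypothetical; the point is simply that every image gets paired with a human-provided label before any training happens.

```python
from pathlib import Path

def load_labeled_paths(root="cats_vs_dogs"):
    """Pair each image path with a label taken from its (human-curated) folder name."""
    labeled = []
    for label in ("cat", "dog"):
        for path in sorted(Path(root, label).glob("*.jpg")):
            labeled.append((path, label))
    return labeled

# e.g. [(Path("cats_vs_dogs/cat/0001.jpg"), "cat"),
#       (Path("cats_vs_dogs/dog/0001.jpg"), "dog"), ...]
```

Nothing fancy is happening here, and that’s the point: the “intelligence” in a labeled dataset comes entirely from the humans who sorted the images, not from the code.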
Ok, now that we have that out of the way, it’s easy to define what I mean when I say we need a lot of data. We just need to collect a bunch of images of cats and dogs with labels, and then we’re good, right? Wrong!
To understand why, you first have to understand how neural networks tend to operate. They are both lazy and stupid. They will learn the easiest possible rules to accomplish their task, and nothing beyond that. To illustrate this, look at the following sets of images:
If I asked you to tell me the difference between those two sets of images, you would say that all the images in Set 1 have a cat in them, and all the images in Set 2 have a dog in them. Very good! You are smarter than a neural network. Imagine that we had millions of images like these and trained a neural network to determine whether a picture was of a cat or of a dog by giving it all of these labeled examples. Do you think that it would successfully learn the difference between a cat and a dog? No! And the reason is this:
A neural network learns the simplest possible rule to accomplish its job.
It’s much harder to tell the difference between two lil’ fuzzy animals with pointy ears and four legs than it is to notice that all the images in Set 1 are taken inside, whereas all the images in Set 2 are taken in a field of grass. Our neural network could learn to decide whether an image is a cat or a dog simply by counting the number of green pixels, and do just as well on all the data we’ve given it. And if it learned this spurious distinction, it would be wrong the second we showed it a cat outside or a dog inside.
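To make that shortcut concrete, here’s a toy sketch of a “classifier” that never looks at the animal at all. The greenness cue and the threshold are made up for illustration, but on a dataset where every dog photo is outdoors on grass and every cat photo is indoors, a rule this dumb scores nearly perfectly.

```python
import numpy as np

def fraction_green(image):
    """Fraction of pixels where the green channel dominates.
    `image` is an H x W x 3 RGB array."""
    r, g, b = image[..., 0], image[..., 1], image[..., 2]
    return np.mean((g > r) & (g > b))

def lazy_classifier(image, threshold=0.3):
    """The shortcut rule: lots of green pixels, so guess "dog".
    It knows nothing about cats or dogs, yet it separates the two
    sets above, and fails on the first dog photographed indoors."""
    return "dog" if fraction_green(image) > threshold else "cat"
```

The point is not that a real network literally counts green pixels; it’s that nothing in this training data stops it from settling on a cue that shallow.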
So even if we have millions of images, our data simply isn’t big enough to teach our neural network anything meaningful. In fact, even with infinitely many images like these, it would still fail. So our data needs to be bigger in a different way: not just in the number of images, but in the diversity among those images.
As another example, let’s consider the following sets of images.
You might be thinking, “This is looking pretty good! We’ve solved the previous problem by adding in more pictures where the cats are in green fields, just like the dogs are, so hopefully this time around our neural network will learn to find actual differences between cats and dogs to accomplish its task!”
Well, I am sorry to disappoint you, but it would find an even subtler rule to learn instead: in Set 1, all the cats are far away, whereas in Set 2, we have a bunch of closeups of dogs. So the neural network would probably learn to differentiate between the two by just looking at the size of the animal in the picture. Again, it would fail the instant we showed it a picture of a dog that’s far away.
This is all to illustrate that big data isn’t enough, at least if “big” just means the number of raw gigabytes in our dataset. We need diverse data, and we have to carefully curate our datasets to teach the neural network something real and meaningful that will actually generalize to unseen examples in the future.
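One practical habit this suggests: before training, check whether a trivial cue already separates your classes. Here’s a sketch of such an audit, reusing the toy greenness cue from earlier; the cue and the dataset are hypothetical, but the idea carries over to any cheap feature you suspect might be a shortcut.

```python
import numpy as np

def fraction_green(image):
    # Same toy cue as before: fraction of pixels where green dominates.
    r, g, b = image[..., 0], image[..., 1], image[..., 2]
    return np.mean((g > r) & (g > b))

def audit_green_shortcut(images, labels):
    """Compare the trivial cue across classes.
    A large gap means the dataset itself encodes the shortcut."""
    greenness = np.array([fraction_green(img) for img in images])
    is_dog = np.array([label == "dog" for label in labels])
    print(f"mean greenness - cats: {greenness[~is_dog].mean():.2f}, "
          f"dogs: {greenness[is_dog].mean():.2f}")
```

The same kind of check applies to the second example above: compare how much of the frame the animal occupies in each class before trusting any accuracy number.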
Hopefully this post has helped explain why building a good dataset is hard and why it’s important, while also hinting that if you put in the effort to diversify your dataset, you’ll get results that are worth it!