How does automatic classification of documents using machine learning work?

A friend asked me to explain how an automatic system for classifying documents, such as AIDR, works.

We are going to do this in three steps: first a preliminary example on the risk of having a heart attack, then a few generalities, then the real thing.

Preliminary: predicting heart attack risk

Imagine a doctor with several patients whom she has been following for several years. She keeps a clinical file for each patient in which she has noted the following: whether the patient smokes (which she writes as "smokes=y" or "smokes=n"), whether the patient has high blood pressure (which she writes as "hypertensive=y" or "hypertensive=n"), and whether the patient practices sports (which she writes as "sports=y" or "sports=n").

Finally, the doctor also notes whether the patient has had a heart attack, written as "STROKE=y" or "STROKE=n":

  • Patient 1: smokes=y, hypertensive=y, sports=n, STROKE=y
  • Patient 2: smokes=y, hypertensive=n, sports=n, STROKE=y
  • Patient 3: smokes=y, hypertensive=n, sports=y, STROKE=n
  • Patient 4: smokes=n, hypertensive=y, sports=y, STROKE=n
  • Patient 5: smokes=n, hypertensive=y, sports=n, STROKE=y

Now, one can extract certain statistics from this data. For instance, patients 3 and 4 practice sports and didn't have a heart attack (STROKE=n), while patients 1, 2, and 5 don't practice sports and did have one (STROKE=y). From this data alone, one could conclude that practicing sports may help prevent a heart attack (where the "may help" part doesn't come from the data itself but from the recognition that 5 patients is not a lot).

We can also learn that two of the three patients who smoke (about 67%) had a heart attack in this sample.
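These statistics can be reproduced with a few lines of Python. This is only a sketch: the list `patients` and the helper `rate` are names made up for illustration, and the data is the five-patient table above.

```python
# The five patient records from the table above, as dictionaries.
patients = [
    {"smokes": "y", "hypertensive": "y", "sports": "n", "STROKE": "y"},
    {"smokes": "y", "hypertensive": "n", "sports": "n", "STROKE": "y"},
    {"smokes": "y", "hypertensive": "n", "sports": "y", "STROKE": "n"},
    {"smokes": "n", "hypertensive": "y", "sports": "y", "STROKE": "n"},
    {"smokes": "n", "hypertensive": "y", "sports": "n", "STROKE": "y"},
]


def rate(records, factor, value):
    """Fraction of records with factor=value whose outcome is STROKE=y."""
    selected = [r for r in records if r[factor] == value]
    if not selected:
        return None  # nobody in the sample has this factor value
    positives = sum(1 for r in selected if r["STROKE"] == "y")
    return positives / len(selected)


print(rate(patients, "smokes", "y"))  # 2 of the 3 smokers
print(rate(patients, "sports", "y"))  # 0 of the 2 who practice sports
print(rate(patients, "sports", "n"))  # 3 of the 3 who do not
```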

Now, if we look for combinations of factors, we can extract more information. For instance, by looking with care at the data, one can realize that everybody in this sample who does not practice sports and either smokes or is hypertensive has had a heart attack, while the two patients who do practice sports have not, even though one of them smokes and the other is hypertensive. Obviously, with more data we can be more certain about how good the combinations of factors that we learn are, in terms of how closely they are related to a certain outcome.

Statistical machine learning

There are so many combinations of factors that even in the small dataset above, with five patients, exploring all the combinations and outcomes is very time consuming. Fortunately, there is a well-established research field, statistical machine learning, that studies precisely this problem.

This research field has studied for years different methods to automatically and quickly find relationships between elements in large-scale data. This process is known as learning, and there are many, many techniques to do it.
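To get a feeling for how quickly the combinations pile up, here is a rough brute-force sketch in Python. It only considers the simplest kind of combination (each factor is required to be "y", required to be "n", or ignored), and the names `conjunctive_rules` and `evaluate` are invented for this illustration. Real learning methods do not search this way.

```python
from itertools import product

# The five patient records from the example above.
patients = [
    {"smokes": "y", "hypertensive": "y", "sports": "n", "STROKE": "y"},
    {"smokes": "y", "hypertensive": "n", "sports": "n", "STROKE": "y"},
    {"smokes": "y", "hypertensive": "n", "sports": "y", "STROKE": "n"},
    {"smokes": "n", "hypertensive": "y", "sports": "y", "STROKE": "n"},
    {"smokes": "n", "hypertensive": "y", "sports": "n", "STROKE": "y"},
]
FACTORS = ["smokes", "hypertensive", "sports"]


def conjunctive_rules(factors):
    """Yield every rule that sets each factor to 'y', to 'n', or ignores it."""
    for values in product(["y", "n", None], repeat=len(factors)):
        rule = {f: v for f, v in zip(factors, values) if v is not None}
        if rule:  # skip the empty rule, which constrains nothing
            yield rule


def evaluate(rule, records):
    """Return (number of matching patients, how many of them have STROKE=y)."""
    matching = [r for r in records if all(r[f] == v for f, v in rule.items())]
    positives = sum(1 for r in matching if r["STROKE"] == "y")
    return len(matching), positives


rules = list(conjunctive_rules(FACTORS))
print(len(rules), "candidate rules for", len(FACTORS), "factors")
for rule in rules:
    matched, positives = evaluate(rule, patients)
    # Show only rules where at least two patients match and all share one outcome.
    if matched >= 2 and positives in (0, matched):
        print(rule, "->", positives, "of", matched, "with STROKE=y")
```

With k yes/no factors, this simple family already contains 3^k - 1 rules, and that is before allowing "either ... or" combinations. With thousands of factors, trying them all is out of the question.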

In general, what these methods need in order to be able to learn effectively is: (i) a large amount of data, and (ii) the "right" data. In the example above, the doctor who interviewed the patients asked the "right" questions. If she had instead written down their eye color or some other irrelevant factor, learning something about heart attack risk would have been much more difficult.

Classifying text

Text classification is not much different. Instead of 3 factors (smokes, hypertensive, and sports), we will have hundreds of thousands of factors, one for each word in the dictionary. The factors will take the form "word=y, word=n" where the "word" can be any word, and we write "y" when the document contains the word and "n" when it doesn't.
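As a minimal sketch of this step in Python, the function below turns a text into "word=y, word=n" factors. The name `word_features`, the four-word vocabulary, and the example sentence are all made up for illustration; a real vocabulary would have many thousands of entries.

```python
import re


def word_features(text, vocabulary):
    """Map a text to {word: 'y' or 'n'} for every word in the vocabulary."""
    # Lowercase the text and keep runs of letters (and apostrophes) as words.
    present = set(re.findall(r"[a-z']+", text.lower()))
    return {word: ("y" if word in present else "n") for word in vocabulary}


# Tiny, made-up vocabulary for illustration only.
vocabulary = ["bridge", "building", "collapsed", "playing"]
features = word_features("The bridge collapsed this morning", vocabulary)
print(features)
# {'bridge': 'y', 'building': 'n', 'collapsed': 'y', 'playing': 'n'}
```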

The outcomes will be different types of documents. Suppose our documents are tweets and we want to separate those that contain information about damage to infrastructure (DAMAGE=y) from those that don't (DAMAGE=n). Again, you can build a table like the following, in which each tweet has one factor for each word it contains, and the outcome has been written by an expert who has looked at the tweet and decided whether it reports infrastructure damage or not:

Tweet1: ... building=y, ..., collapsed=y, ..., DAMAGE=y
Tweet2: ... bridge=y, ..., collapsed=y, ..., DAMAGE=y
Tweet3: ... bridge=y, ..., playing=y, ..., DAMAGE=n
Tweet1000: ... bridge=y, ..., hearts=y, ..., DAMAGE=n

Again, we can apply any of the statistical machine learning methods to learn which combinations of words indicate the presence of infrastructure damage reports.

That's all. Once we learn those combinations, we can use them automatically to evaluate new tweets. In this case, the learning method will also output a confidence, which you can understand roughly as the percentage of tweets having those factors that were found to have that predicted outcome in the data used to learn (it is more complex than that, but that is a good approximation).
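To make the learn-then-predict loop concrete, here is a small sketch in plain Python. It uses naive Bayes, which is just one of the many learning methods and not necessarily the one a system like AIDR uses. The six training tweets are invented for this example, and `train` and `predict` are hypothetical helper names.

```python
import math
from collections import Counter

# Made-up training tweets, labeled by hand (DAMAGE=y or DAMAGE=n).
training = [
    ("building collapsed downtown", "y"),
    ("bridge collapsed after the quake", "y"),
    ("the old bridge collapsed", "y"),
    ("playing bridge with friends", "n"),
    ("playing hearts and bridge tonight", "n"),
    ("friends playing cards downtown", "n"),
]


def words(text):
    """The set of lowercase words in a text (word=y for each of them)."""
    return set(text.lower().split())


def train(examples):
    """Count, per label, the number of tweets and how many contain each word."""
    label_counts = Counter()
    word_counts = {}
    vocabulary = set()
    for text, label in examples:
        label_counts[label] += 1
        word_counts.setdefault(label, Counter()).update(words(text))
        vocabulary |= words(text)
    return label_counts, word_counts, vocabulary


def predict(model, text):
    """Return (predicted label, confidence between 0 and 1)."""
    label_counts, word_counts, vocabulary = model
    present = words(text)
    total = sum(label_counts.values())
    log_scores = {}
    for label, n in label_counts.items():
        score = math.log(n / total)
        for word in vocabulary:
            # Smoothed estimate of how often tweets with this label contain the word.
            p = (word_counts[label][word] + 1) / (n + 2)
            score += math.log(p if word in present else 1 - p)
        log_scores[label] = score
    # Turn the scores into probabilities that sum to 1.
    top = max(log_scores.values())
    exp_scores = {label: math.exp(s - top) for label, s in log_scores.items()}
    best = max(exp_scores, key=exp_scores.get)
    return best, exp_scores[best] / sum(exp_scores.values())


model = train(training)
print(predict(model, "bridge collapsed near the station"))
print(predict(model, "playing bridge tonight"))
```

Words that never appeared in the training tweets (such as "station") are simply ignored, and the confidence is the model's estimate, not a guarantee.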

When the data is large, it is in general very difficult for a human to spot patterns better than a computer algorithm can. This is why crafting rules by hand (containing "bridge" implies "DAMAGE" unless the tweet also contains "playing" or "play" or "ace" or "heart" or ...) is not only time-consuming but also tends to yield lower accuracy than automatic methods, and is in general a bad idea.

In the case of text, we also use other factors (we call them "features" or "attributes") in addition to words. For instance, we can take all sequences of two words or three words (which we call "word bi-grams" and "word tri-grams"). We can also look at the position of some words in the phrase, whether a word was capitalized or not, how many times it appeared, etc. For the learning method, this is all the same: simply more factors that can be exploited to learn about the data.
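A rough Python sketch of these extra features follows. The function names and the "ngram=", "count=", "first_position=" and "capitalized=" prefixes are naming choices made up for this example, not a standard.

```python
def word_ngrams(text, n):
    """Return every sequence of n consecutive words in the text."""
    tokens = text.split()
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def extra_features(text):
    """Build bi-gram, tri-gram, count, position and capitalization features."""
    features = {}
    lowered = text.lower()
    for gram in word_ngrams(lowered, 2) + word_ngrams(lowered, 3):
        features["ngram=" + gram] = "y"
    for position, token in enumerate(text.split()):
        word = token.lower()
        # How many times the word appears in the text.
        features["count=" + word] = features.get("count=" + word, 0) + 1
        # Position of the first occurrence (0 is the first word).
        features.setdefault("first_position=" + word, position)
        if token[0].isupper():
            features["capitalized=" + word] = "y"
    return features


print(word_ngrams("the bridge collapsed", 2))
# ['the bridge', 'bridge collapsed']
print(extra_features("Bridge collapsed near the old bridge"))
```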

Further reading: there is lots of it, but you can start with the Wikipedia page on decision trees, which are a popular and easy-to-understand method for statistical machine learning.