View Javadoc

1   /****************************************************************
2    * Licensed to the Apache Software Foundation (ASF) under one   *
3    * or more contributor license agreements.  See the NOTICE file *
4    * distributed with this work for additional information        *
5    * regarding copyright ownership.  The ASF licenses this file   *
6    * to you under the Apache License, Version 2.0 (the            *
7    * "License"); you may not use this file except in compliance   *
8    * with the License.  You may obtain a copy of the License at   *
9    *                                                              *
10   *   http://www.apache.org/licenses/LICENSE-2.0                 *
11   *                                                              *
12   * Unless required by applicable law or agreed to in writing,   *
13   * software distributed under the License is distributed on an  *
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY       *
15   * KIND, either express or implied.  See the License for the    *
16   * specific language governing permissions and limitations      *
17   * under the License.                                           *
18   ****************************************************************/
19  
20  package org.apache.james.ai.classic;
21  
22  import java.io.IOException;
23  import java.io.Reader;
24  import java.util.ArrayList;
25  import java.util.Collection;
26  import java.util.HashMap;
27  import java.util.HashSet;
28  import java.util.Iterator;
29  import java.util.Map;
30  import java.util.Set;
31  import java.util.SortedSet;
32  import java.util.TreeSet;
33  
34  /**
35   * <p>
36   * Determines probability that text contains Spam.
37   * </p>
38   * 
39   * <p>
40   * Based upon Paul Graham's <a href="http://www.paulgraham.com/spam.html">A Plan
41   * for Spam</a>. Extended to Paul Graham's <a
42   * href="http://paulgraham.com/better.html">Better Bayesian Filtering</a>.
43   * </p>
44   * 
45   * <p>
46   * Sample method usage:
47   * </p>
48   * 
49   * <p>
50   * Use: void addHam(Reader) and void addSpam(Reader)
51   * 
52   * methods to build up the Maps of ham & spam tokens/occurrences. Both addHam
53   * and addSpam assume they're reading one message at a time, if you feed more
54   * than one message per call, be sure to adjust the appropriate message counter:
55   * hamMessageCount or spamMessageCount.
56   * 
57   * Then...
58   * </p>
59   * 
60   * <p>
61   * Use: void buildCorpus()
62   * 
63   * to build the final token/probabilities Map.
64   * 
65   * Use your own methods for persistent storage of either the individual ham/spam
66   * corpus & message counts, and/or the final corpus.
67   * 
68   * Then you can...
69   * </p>
70   * 
71   * <p>
72   * Use: double computeSpamProbability(Reader)
73   * 
74   * to determine the probability that a particular text contains spam. A returned
75   * result of 0.9 or above is an indicator that the text was spam.
76   * </p>
77   * 
78   * <p>
79   * If you use persistent storage, use: void setCorpus(Map)
80   * 
81   * before calling computeSpamProbability.
82   * </p>
83   * 
84   * @since 2.3.0
85   */
86  
87  public class BayesianAnalyzer {
88  
89      /**
90       * Number of "interesting" tokens to use to compute overall spamminess
91       * probability.
92       */
93      private final static int MAX_INTERESTING_TOKENS = 15;
94  
95      /**
96       * Minimum probability distance from 0.5 to consider a token "interesting"
97       * to use to compute overall spamminess probability.
98       */
99      private final static double INTERESTINGNESS_THRESHOLD = 0.46;
100 
101     /**
102      * Default token probability to use when a token has not been encountered
103      * before.
104      */
105     private final static double DEFAULT_TOKEN_PROBABILITY = 0.4;
106 
107     /** Map of ham tokens and their occurrences. */
108     private Map<String, Integer> hamTokenCounts = new HashMap<String, Integer>();
109 
110     /** Map of spam tokens and their occurrences. */
111     private Map<String, Integer> spamTokenCounts = new HashMap<String, Integer>();
112 
113     /** Number of ham messages analyzed. */
114     private int hamMessageCount = 0;
115 
116     /** Number of spam messages analyzed. */
117     private int spamMessageCount = 0;
118 
119     /** Final token/probability corpus. */
120     private Map<String, Double> corpus = new HashMap<String, Double>();
121 
122     /**
123      * Inner class for managing Token Probability Strengths during the
124      * computeSpamProbability phase.
125      * 
126      * By probability <i>strength</i> we mean the absolute distance of a
127      * probability from the middle value 0.5.
128      * 
129      * It implements Comparable so that it's sorting is automatic.
130      */
131     private class TokenProbabilityStrength implements Comparable<TokenProbabilityStrength> {
132         /**
133          * Message token.
134          */
135         String token = null;
136 
137         /**
138          * Token's computed probability strength.
139          */
140         double strength = Math.abs(0.5 - DEFAULT_TOKEN_PROBABILITY);
141 
142         /**
143          * Force the natural sort order for this object to be high-to-low.
144          * 
145          * @param anotherTokenProbabilityStrength
146          *            A TokenProbabilityStrength instance to compare this
147          *            instance with.
148          * 
149          * @return The result of the comparison (before, equal, after).
150          */
151         public final int compareTo(TokenProbabilityStrength anotherTokenProbabilityStrength) {
152             int result = (int) ((((TokenProbabilityStrength) anotherTokenProbabilityStrength).strength - strength) * 1000000);
153             if (result == 0) {
154                 return this.token.compareTo(((TokenProbabilityStrength) anotherTokenProbabilityStrength).token);
155             } else {
156                 return result;
157             }
158         }
159 
160         /**
161          * Simple toString () implementation mostly for debugging purposes.
162          * 
163          * @return String representation of this object.
164          */
165         public String toString() {
166             StringBuffer sb = new StringBuffer(30);
167 
168             sb.append(token).append("=").append(strength);
169 
170             return sb.toString();
171         }
172     }
173 
174     /**
175      * Basic class constructor.
176      */
177     public BayesianAnalyzer() {
178     }
179 
180     /**
181      * Public setter for the hamTokenCounts Map.
182      * 
183      * @param hamTokenCounts
184      *            The new ham Token counts Map.
185      */
186     public void setHamTokenCounts(Map<String, Integer> hamTokenCounts) {
187         this.hamTokenCounts = hamTokenCounts;
188     }
189 
190     /**
191      * Public getter for the hamTokenCounts Map.
192      */
193     public Map<String, Integer> getHamTokenCounts() {
194         return this.hamTokenCounts;
195     }
196 
197     /**
198      * Public setter for the spamTokenCounts Map.
199      * 
200      * @param spamTokenCounts
201      *            The new spam Token counts Map.
202      */
203     public void setSpamTokenCounts(Map<String, Integer> spamTokenCounts) {
204         this.spamTokenCounts = spamTokenCounts;
205     }
206 
207     /**
208      * Public getter for the spamTokenCounts Map.
209      */
210     public Map<String, Integer> getSpamTokenCounts() {
211         return this.spamTokenCounts;
212     }
213 
214     /**
215      * Public setter for spamMessageCount.
216      * 
217      * @param spamMessageCount
218      *            The new spam message count.
219      */
220     public void setSpamMessageCount(int spamMessageCount) {
221         this.spamMessageCount = spamMessageCount;
222     }
223 
224     /**
225      * Public getter for spamMessageCount.
226      */
227     public int getSpamMessageCount() {
228         return this.spamMessageCount;
229     }
230 
231     /**
232      * Public setter for hamMessageCount.
233      * 
234      * @param hamMessageCount
235      *            The new ham message count.
236      */
237     public void setHamMessageCount(int hamMessageCount) {
238         this.hamMessageCount = hamMessageCount;
239     }
240 
241     /**
242      * Public getter for hamMessageCount.
243      */
244     public int getHamMessageCount() {
245         return this.hamMessageCount;
246     }
247 
248     /**
249      * Clears all analysis repositories and counters.
250      */
251     public void clear() {
252         corpus.clear();
253 
254         tokenCountsClear();
255 
256         hamMessageCount = 0;
257         spamMessageCount = 0;
258     }
259 
260     /**
261      * Clears token counters.
262      */
263     public void tokenCountsClear() {
264         hamTokenCounts.clear();
265         spamTokenCounts.clear();
266     }
267 
268     /**
269      * Public setter for corpus.
270      * 
271      * @param corpus
272      *            The new corpus.
273      */
274     public void setCorpus(Map<String, Double> corpus) {
275         this.corpus = corpus;
276     }
277 
278     /**
279      * Public getter for corpus.
280      */
281     public Map<String, Double> getCorpus() {
282         return this.corpus;
283     }
284 
285     /**
286      * Builds the corpus from the existing ham & spam counts.
287      */
288     public void buildCorpus() {
289         // Combine the known ham & spam tokens.
290         Set<String> set = new HashSet<String>(hamTokenCounts.size() + spamTokenCounts.size());
291         set.addAll(hamTokenCounts.keySet());
292         set.addAll(spamTokenCounts.keySet());
293         Map<String, Double> tempCorpus = new HashMap<String, Double>(set.size());
294 
295         // Iterate through all the tokens and compute their new
296         // individual probabilities.
297         Iterator<String> i = set.iterator();
298         while (i.hasNext()) {
299             String token = i.next();
300             tempCorpus.put(token, new Double(computeProbability(token)));
301         }
302         setCorpus(tempCorpus);
303     }
304 
305     /**
306      * Adds a message to the ham list.
307      * 
308      * @param stream
309      *            A reader stream on the ham message to analyze
310      * @throws IOException
311      *             If any error occurs
312      */
313     public void addHam(Reader stream) throws java.io.IOException {
314         addTokenOccurrences(stream, hamTokenCounts);
315         hamMessageCount++;
316     }
317 
318     /**
319      * Adds a message to the spam list.
320      * 
321      * @param stream
322      *            A reader stream on the spam message to analyze
323      * @throws IOException
324      *             If any error occurs
325      */
326     public void addSpam(Reader stream) throws java.io.IOException {
327         addTokenOccurrences(stream, spamTokenCounts);
328         spamMessageCount++;
329     }
330 
331     /**
332      * Computes the probability that the stream contains SPAM.
333      * 
334      * @param stream
335      *            The text to be analyzed for Spamminess.
336      * @return A 0.0 - 1.0 probability
337      * @throws IOException
338      *             If any error occurs
339      */
340     public double computeSpamProbability(Reader stream) throws java.io.IOException {
341         // Build a set of the tokens in the Stream.
342         Set<String> tokens = parse(stream);
343 
344         // Get the corpus to use in this run
345         // A new corpus may be being built in the meantime
346         Map<String, Double> workCorpus = getCorpus();
347 
348         // Assign their probabilities from the Corpus (using an additional
349         // calculation to determine spamminess).
350         SortedSet<TokenProbabilityStrength> tokenProbabilityStrengths = getTokenProbabilityStrengths(tokens, workCorpus);
351 
352         // Compute and return the overall probability that the
353         // stream is SPAM.
354         return computeOverallProbability(tokenProbabilityStrengths, workCorpus);
355     }
356 
357     /**
358      * Parses a stream into tokens, and updates the target Map with the
359      * token/counts.
360      * 
361      * @param stream
362      * @param target
363      */
364     private void addTokenOccurrences(Reader stream, Map<String, Integer> target) throws java.io.IOException {
365         String token;
366         String header = "";
367 
368         // Update target with the tokens/count encountered.
369         while ((token = nextToken(stream)) != null) {
370             boolean endingLine = false;
371             if (token.length() > 0 && token.charAt(token.length() - 1) == '\n') {
372                 endingLine = true;
373                 token = token.substring(0, token.length() - 1);
374             }
375 
376             if (token.length() > 0 && header.length() + token.length() < 90 && !allDigits(token)) {
377                 if (token.equals("From:") || token.equals("Return-Path:") || token.equals("Subject:") || token.equals("To:")) {
378                     header = token;
379                     if (!endingLine) {
380                         continue;
381                     }
382                 }
383 
384                 token = header + token;
385 
386                 Integer value = null;
387 
388                 if (target.containsKey(token)) {
389                     value = Integer.valueOf(((Integer) target.get(token)).intValue() + 1);
390                 } else {
391                     value = Integer.valueOf(1);
392                 }
393 
394                 target.put(token, value);
395             }
396 
397             if (endingLine) {
398                 header = "";
399             }
400         }
401     }
402 
403     /**
404      * Parses a stream into tokens, and returns a Set of the unique tokens
405      * encountered.
406      * 
407      * @param stream
408      * @return Set
409      */
410     private Set<String> parse(Reader stream) throws java.io.IOException {
411         Set<String> tokens = new HashSet<String>();
412         String token;
413         String header = "";
414 
415         // Build a Map of tokens encountered.
416         while ((token = nextToken(stream)) != null) {
417             boolean endingLine = false;
418             if (token.length() > 0 && token.charAt(token.length() - 1) == '\n') {
419                 endingLine = true;
420                 token = token.substring(0, token.length() - 1);
421             }
422 
423             if (token.length() > 0 && header.length() + token.length() < 90 && !allDigits(token)) {
424                 if (token.equals("From:") || token.equals("Return-Path:") || token.equals("Subject:") || token.equals("To:")) {
425                     header = token;
426                     if (!endingLine) {
427                         continue;
428                     }
429                 }
430 
431                 token = header + token;
432 
433                 tokens.add(token);
434             }
435 
436             if (endingLine) {
437                 header = "";
438             }
439         }
440 
441         // Return the unique set of tokens encountered.
442         return tokens;
443     }
444 
445     private String nextToken(Reader reader) throws java.io.IOException {
446         StringBuffer token = new StringBuffer();
447         int i;
448         char ch, ch2;
449         boolean previousWasDigit = false;
450         boolean tokenCharFound = false;
451 
452         if (!reader.ready()) {
453             return null;
454         }
455 
456         while ((i = reader.read()) != -1) {
457 
458             ch = (char) i;
459 
460             if (ch == ':') {
461                 String tokenString = token.toString() + ':';
462                 if (tokenString.equals("From:") || tokenString.equals("Return-Path:") || tokenString.equals("Subject:") || tokenString.equals("To:")) {
463                     return tokenString;
464                 }
465             }
466 
467             if (Character.isLetter(ch) || ch == '-' || ch == '$' || ch == '\u20AC' // the
468                                                                                    // EURO
469                                                                                    // symbol
470                     || ch == '!' || ch == '\'') {
471                 tokenCharFound = true;
472                 previousWasDigit = false;
473                 token.append(ch);
474             } else if (Character.isDigit(ch)) {
475                 tokenCharFound = true;
476                 previousWasDigit = true;
477                 token.append(ch);
478             } else if (previousWasDigit && (ch == '.' || ch == ',')) {
479                 reader.mark(1);
480                 previousWasDigit = false;
481                 i = reader.read();
482                 if (i == -1) {
483                     break;
484                 }
485                 ch2 = (char) i;
486                 if (Character.isDigit(ch2)) {
487                     tokenCharFound = true;
488                     previousWasDigit = true;
489                     token.append(ch);
490                     token.append(ch2);
491                 } else {
492                     reader.reset();
493                     break;
494                 }
495             } else if (ch == '\r') {
496                 // cr found, ignore
497             } else if (ch == '\n') {
498                 // eol found
499                 tokenCharFound = true;
500                 previousWasDigit = false;
501                 token.append(ch);
502                 break;
503             } else if (tokenCharFound) {
504                 break;
505             }
506         }
507 
508         if (tokenCharFound) {
509             // System.out.println("Token read: " + token);
510             return token.toString();
511         } else {
512             return null;
513         }
514     }
515 
516     /**
517      * Compute the probability that "token" is SPAM.
518      * 
519      * @param token
520      * @return The probability that the token occurs within spam.
521      */
522     private double computeProbability(String token) {
523         double hamFactor = 0;
524         double spamFactor = 0;
525 
526         boolean foundInHam = false;
527         boolean foundInSpam = false;
528 
529         double minThreshold = 0.01;
530         double maxThreshold = 0.99;
531 
532         if (hamTokenCounts.containsKey(token)) {
533             foundInHam = true;
534         }
535 
536         if (spamTokenCounts.containsKey(token)) {
537             foundInSpam = true;
538         }
539 
540         if (foundInHam) {
541             hamFactor = 2 * ((Integer) hamTokenCounts.get(token)).doubleValue();
542             if (!foundInSpam) {
543                 minThreshold = (hamFactor > 20) ? 0.0001 : 0.0002;
544             }
545         }
546 
547         if (foundInSpam) {
548             spamFactor = ((Integer) spamTokenCounts.get(token)).doubleValue();
549             if (!foundInHam) {
550                 maxThreshold = (spamFactor > 10) ? 0.9999 : 0.9998;
551             }
552         }
553 
554         if ((hamFactor + spamFactor) < 5) {
555             // This token hasn't been seen enough.
556             return 0.4;
557         }
558 
559         double spamFreq = Math.min(1.0, spamFactor / spamMessageCount);
560         double hamFreq = Math.min(1.0, hamFactor / hamMessageCount);
561 
562         return Math.max(minThreshold, Math.min(maxThreshold, (spamFreq / (hamFreq + spamFreq))));
563     }
564 
565     /**
566      * Returns a SortedSet of TokenProbabilityStrength built from the Corpus and
567      * the tokens passed in the "tokens" Set. The ordering is from the highest
568      * strength to the lowest strength.
569      * 
570      * @param tokens
571      * @param workCorpus
572      * @return SortedSet of TokenProbabilityStrength objects.
573      */
574     private SortedSet<TokenProbabilityStrength> getTokenProbabilityStrengths(Set<String> tokens, Map<String, Double> workCorpus) {
575         // Convert to a SortedSet of token probability strengths.
576         SortedSet<TokenProbabilityStrength> tokenProbabilityStrengths = new TreeSet<TokenProbabilityStrength>();
577 
578         Iterator<String> i = tokens.iterator();
579         while (i.hasNext()) {
580             TokenProbabilityStrength tps = new TokenProbabilityStrength();
581 
582             tps.token = (String) i.next();
583 
584             if (workCorpus.containsKey(tps.token)) {
585                 tps.strength = Math.abs(0.5 - ((Double) workCorpus.get(tps.token)).doubleValue());
586             } else {
587                 // This token has never been seen before,
588                 // we'll give it initially the default probability.
589                 Double corpusProbability = new Double(DEFAULT_TOKEN_PROBABILITY);
590                 tps.strength = Math.abs(0.5 - DEFAULT_TOKEN_PROBABILITY);
591                 boolean isTokenDegeneratedFound = false;
592 
593                 Collection<String> degeneratedTokens = buildDegenerated(tps.token);
594                 Iterator<String> iDegenerated = degeneratedTokens.iterator();
595                 String tokenDegenerated = null;
596                 double strengthDegenerated;
597                 while (iDegenerated.hasNext()) {
598                     tokenDegenerated = (String) iDegenerated.next();
599                     if (workCorpus.containsKey(tokenDegenerated)) {
600                         Double probabilityTemp = (Double) workCorpus.get(tokenDegenerated);
601                         strengthDegenerated = Math.abs(0.5 - probabilityTemp.doubleValue());
602                         if (strengthDegenerated > tps.strength) {
603                             isTokenDegeneratedFound = true;
604                             tps.strength = strengthDegenerated;
605                             corpusProbability = probabilityTemp;
606                         }
607                     }
608                 }
609                 // to reduce memory usage, put in the corpus only if the
610                 // probability is different from (stronger than) the default
611                 if (isTokenDegeneratedFound) {
612                     synchronized (workCorpus) {
613                         workCorpus.put(tps.token, corpusProbability);
614                     }
615                 }
616             }
617 
618             tokenProbabilityStrengths.add(tps);
619         }
620 
621         return tokenProbabilityStrengths;
622     }
623 
624     private Collection<String> buildDegenerated(String fullToken) {
625         ArrayList<String> tokens = new ArrayList<String>();
626         String header;
627         String token;
628         String tokenLower;
629 
630         // look for a header string termination
631         int headerEnd = fullToken.indexOf(':');
632         if (headerEnd >= 0) {
633             header = fullToken.substring(0, headerEnd);
634             token = fullToken.substring(headerEnd);
635         } else {
636             header = "";
637             token = fullToken;
638         }
639 
640         // prepare a version of the token containing all lower case (for
641         // performance reasons)
642         tokenLower = token.toLowerCase();
643 
644         int end = token.length();
645         do {
646             if (!token.substring(0, end).equals(tokenLower.substring(0, end))) {
647                 tokens.add(header + tokenLower.substring(0, end));
648                 if (header.length() > 0) {
649                     tokens.add(tokenLower.substring(0, end));
650                 }
651             }
652             if (end > 1 && token.charAt(0) >= 'A' && token.charAt(0) <= 'Z') {
653                 tokens.add(header + token.charAt(0) + tokenLower.substring(1, end));
654                 if (header.length() > 0) {
655                     tokens.add(token.charAt(0) + tokenLower.substring(1, end));
656                 }
657             }
658 
659             if (token.charAt(end - 1) != '!') {
660                 break;
661             }
662 
663             end--;
664 
665             tokens.add(header + token.substring(0, end));
666             if (header.length() > 0) {
667                 tokens.add(token.substring(0, end));
668             }
669         } while (end > 0);
670 
671         return tokens;
672     }
673 
674     /**
675      * Compute the spamminess probability of the interesting tokens in the
676      * tokenProbabilities SortedSet.
677      * 
678      * @param tokenProbabilityStrengths
679      * @param workCorpus
680      * @return Computed spamminess.
681      */
682     private double computeOverallProbability(SortedSet<TokenProbabilityStrength> tokenProbabilityStrengths, Map<String, Double> workCorpus) {
683         double p = 1.0;
684         double np = 1.0;
685         double tempStrength = 0.5;
686         int count = MAX_INTERESTING_TOKENS;
687         Iterator<TokenProbabilityStrength> iterator = tokenProbabilityStrengths.iterator();
688         while ((iterator.hasNext()) && (count-- > 0 || tempStrength >= INTERESTINGNESS_THRESHOLD)) {
689             TokenProbabilityStrength tps = iterator.next();
690             tempStrength = tps.strength;
691 
692             // System.out.println(tps);
693 
694             double theDoubleValue = DEFAULT_TOKEN_PROBABILITY; // initialize it
695                                                                // to the default
696             Double theDoubleObject = (Double) workCorpus.get(tps.token);
697             // if either the original token or a degeneration was found use the
698             // double value, otherwise use the default
699             if (theDoubleObject != null) {
700                 theDoubleValue = theDoubleObject.doubleValue();
701             }
702             p *= theDoubleValue;
703             np *= (1.0 - theDoubleValue);
704             // System.out.println("Token " + tps + ", p=" + theDoubleValue +
705             // ", overall p=" + p / (p + np));
706         }
707 
708         return (p / (p + np));
709     }
710 
711     private boolean allDigits(String s) {
712         for (int i = 0; i < s.length(); i++) {
713             if (!Character.isDigit(s.charAt(i))) {
714                 return false;
715             }
716         }
717         return true;
718     }
719 }