Package nltk :: Package classify :: Module weka
[hide private]
[frames] | [no frames]

Source Code for Module nltk.classify.weka

  1  # Natural Language Toolkit: Interface to Weka Classifiers 
  2  # 
  3  # Copyright (C) 2001-2008 NLTK Project 
  4  # Author: Edward Loper <[email protected]> 
  5  # URL: <http://nltk.org> 
  6  # For license information, see LICENSE.TXT 
  7  # 
  8  # $Id: naivebayes.py 2063 2004-07-17 21:02:24Z edloper $ 
  9   
 10  """ 
 11  Classifiers that make use of the external 'Weka' package. 
 12  """ 
 13   
 14  import time 
 15  import tempfile 
 16  import os 
 17  import os.path 
 18  import subprocess 
 19  import re 
 20  import zipfile 
 21   
 22  from nltk.probability import * 
 23  from nltk.internals import java, config_java 
 24   
 25  from api import * 
 26   
 27  _weka_classpath = None 
 28  _weka_search = ['.', 
 29                  '/usr/share/weka', 
 30                  '/usr/local/share/weka', 
 31                  '/usr/lib/weka', 
 32                  '/usr/local/lib/weka',] 
def config_weka(classpath=None):
    """
    Configure NLTK's interface to Weka: locate weka.jar and record
    its location in the module-global C{_weka_classpath}.

    @param classpath: An explicit path to weka.jar.  If not given,
        the WEKAHOME environment variable and a list of standard
        installation directories are searched instead.
    @raise LookupError: If weka.jar cannot be found anywhere.
    """
    global _weka_classpath

    # Make sure java's configured first.
    config_java()

    if classpath is not None:
        _weka_classpath = classpath

    if _weka_classpath is None:
        # Copy the search list: the original code aliased the module
        # global and insert()ed into it, so repeated calls kept
        # prepending WEKAHOME entries.
        searchpath = list(_weka_search)
        if 'WEKAHOME' in os.environ:
            searchpath.insert(0, os.environ['WEKAHOME'])

        for path in searchpath:
            jar = os.path.join(path, 'weka.jar')
            if os.path.exists(jar):
                _weka_classpath = jar
                version = _check_weka_version(jar)
                if version:
                    print('[Found Weka: %s (version %s)]' %
                          (jar, version))
                else:
                    print('[Found Weka: %s]' % jar)
                # Stop at the first hit: earlier search-path entries
                # take priority.  (The original kept scanning, so a
                # later match silently overrode an earlier one.)
                break

    if _weka_classpath is None:
        raise LookupError('Unable to find weka.jar! Use config_weka() '
                          'or set the WEKAHOME environment variable. '
                          'For more information about Weka, please see '
                          'http://www.cs.waikato.ac.nz/ml/weka/')
63
64 -def _check_weka_version(jar):
65 try: 66 zf = zipfile.ZipFile(jar) 67 except SystemExit, KeyboardInterrupt: 68 raise 69 except: 70 return None 71 try: 72 try: 73 return zf.read('weka/core/version.txt') 74 except KeyError: 75 return None 76 finally: 77 zf.close()
78
class WekaClassifier(ClassifierI):
    """
    A classifier backed by an externally-trained Weka model.  Each
    call to C{batch_classify} or C{batch_prob_classify} writes the
    featuresets to a temporary ARFF file and shells out to Weka (via
    C{nltk.internals.java}) to classify them.
    """
    def __init__(self, formatter, model_filename):
        """
        @param formatter: An C{ARFF_Formatter} used to render
            featuresets into the ARFF files that Weka reads.
        @param model_filename: Path to the serialized Weka model.
        """
        self._formatter = formatter
        self._model = model_filename

    def batch_prob_classify(self, featuresets):
        """Return a list of probability distributions, one per featureset."""
        return self._batch_classify(featuresets, ['-p', '0', '-distribution'])

    def batch_classify(self, featuresets):
        """Return a list of predicted labels, one per featureset."""
        return self._batch_classify(featuresets, ['-p', '0'])

    def _batch_classify(self, featuresets, options):
        # Make sure we can find java & weka.
        config_weka()

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the test data file.
            test_filename = os.path.join(temp_dir, 'test.arff')
            self._formatter.write(test_filename, featuresets)

            # Call weka to classify the data.
            # NOTE(review): the classifier class here is hard-coded to
            # NaiveBayes even though train() supports other classifiers;
            # confirm whether Weka requires the class named with -l to
            # match the saved model.
            cmd = ['weka.classifiers.bayes.NaiveBayes',
                   '-l', self._model, '-T', test_filename] + options
            (stdout, stderr) = java(cmd, classpath=_weka_classpath,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE)

            # Check if something went wrong:
            if stderr and not stdout:
                if 'Illegal options: -distribution' in stderr:
                    # (Fixed typo in this message: 'verison' -> 'version'.)
                    raise ValueError('The installed version of weka does '
                                     'not support probability distribution '
                                     'output.')
                else:
                    raise ValueError('Weka failed to generate output:\n%s'
                                     % stderr)

            # Parse weka's output.
            return self.parse_weka_output(stdout.split('\n'))

        finally:
            # Remove the temporary ARFF file and its directory.
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)

    def parse_weka_distribution(self, s):
        """Parse one '-distribution' column (numbers separated by ','
        and/or '*') into a C{DictionaryProbDist} over this
        classifier's labels."""
        probs = [float(v) for v in re.split('[*,]+', s) if v.strip()]
        probs = dict(zip(self._formatter.labels(), probs))
        return DictionaryProbDist(probs)

    def parse_weka_output(self, lines):
        """Parse Weka's textual prediction output into a list of
        labels, or of probability distributions when Weka was run
        with '-distribution'."""
        if lines[0].split() == ['inst#', 'actual', 'predicted',
                                'error', 'prediction']:
            return [line.split()[2].split(':')[1]
                    for line in lines[1:] if line.strip()]
        elif lines[0].split() == ['inst#', 'actual', 'predicted',
                                  'error', 'distribution']:
            return [self.parse_weka_distribution(line.split()[-1])
                    for line in lines[1:] if line.strip()]

        # NOTE(review): some Weka versions emit headerless
        # "0 label 1.0 ?" rows; this regex match is a heuristic --
        # confirm it against the Weka versions in use.
        elif re.match(r'^0 \w+ [01]\.[0-9]* \?\s*$', lines[0]):
            return [line.split()[1] for line in lines if line.strip()]

        else:
            # Dump a sample of the unrecognized output to aid debugging.
            for line in lines[:10]: print(line)
            raise ValueError('Unhandled output format -- your version '
                             'of weka may not be supported.\n'
                             '  Header: %s' % lines[0])

    # [xx] full list of classifiers (some may be abstract?):
    # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
    # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
    # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
    # LogisticBase, M5Base, MultilayerPerceptron,
    # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
    # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
    # PreConstructedLinearModel, Prism, RandomForest,
    # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
    # RuleNode, SimpleLinearRegression, SimpleLogistic,
    # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
    # VotedPerceptron, Winnow, ZeroR

    # Friendly names accepted by train(), mapped to the Weka class.
    _CLASSIFIER_CLASS = {
        'naivebayes': 'weka.classifiers.bayes.NaiveBayes',
        'C4.5': 'weka.classifiers.trees.J48',
        'log_regression': 'weka.classifiers.functions.Logistic',
        'svm': 'weka.classifiers.functions.SMO',
        # Bug fix: the Weka class is spelled 'KStar', not 'kstar'.
        'kstar': 'weka.classifiers.lazy.KStar',
        'ripper': 'weka.classifiers.rules.JRip',
        }

    @classmethod
    def train(cls, model_filename, featuresets,
              classifier='naivebayes', options=None, quiet=True):
        """
        Train a new Weka model and return a C{WekaClassifier} for it.

        @param model_filename: Path where the trained model is saved.
        @param featuresets: Labeled training data: a list of
            C{(featureset, label)} pairs.
        @param classifier: A key of C{_CLASSIFIER_CLASS} (e.g.
            'naivebayes', 'C4.5') or a fully-qualified Weka class name.
        @param options: Extra command-line options for the Weka
            trainer (default: none).
        @param quiet: If true, discard Weka's stdout.
        @raise ValueError: If C{classifier} is not recognized.
        """
        # Replaced the mutable default argument (options=[]) with the
        # None sentinel; behavior is unchanged.
        if options is None:
            options = []

        # Make sure we can find java & weka.
        config_weka()

        # Build an ARFF formatter.
        formatter = ARFF_Formatter.from_train(featuresets)

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the training data file.
            train_filename = os.path.join(temp_dir, 'train.arff')
            formatter.write(train_filename, featuresets)

            if classifier in cls._CLASSIFIER_CLASS:
                javaclass = cls._CLASSIFIER_CLASS[classifier]
            elif classifier in cls._CLASSIFIER_CLASS.values():
                javaclass = classifier
            else:
                raise ValueError('Unknown classifier %s' % classifier)

            # Train the weka model.
            cmd = [javaclass, '-d', model_filename, '-t', train_filename]
            cmd += list(options)
            if quiet: stdout = subprocess.PIPE
            else: stdout = None
            java(cmd, classpath=_weka_classpath, stdout=stdout)

            # Return the new classifier.
            return WekaClassifier(formatter, model_filename)

        finally:
            # Remove the temporary training file and its directory.
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)
208 209
class ARFF_Formatter:
    """
    Converts featuresets and labeled featuresets to ARFF-formatted
    strings, appropriate for input into Weka.
    """
    def __init__(self, labels, features):
        """
        @param labels: A list of all labels that can be generated.
        @param features: A list of feature specifications, where
            each feature specification is a tuple (fname, ftype);
            and ftype is an ARFF type string such as NUMERIC or
            STRING.
        """
        self._labels = labels
        self._features = features

    def format(self, tokens):
        """Return a complete ARFF file (header + data) for C{tokens}."""
        return self.header_section() + self.data_section(tokens)

    def labels(self):
        """Return a (copied) list of this formatter's labels."""
        return list(self._labels)

    def write(self, filename, tokens):
        """Write the ARFF rendering of C{tokens} to C{filename}."""
        f = open(filename, 'w')
        try:
            f.write(self.format(tokens))
        finally:
            # Close the handle even if formatting raises part-way.
            f.close()

    @staticmethod
    def from_train(tokens):
        """
        Build an C{ARFF_Formatter} from training data: collect the
        set of attested labels and infer an ARFF type for each
        feature from its values.

        @param tokens: A list of C{(featureset, label)} pairs.
        @raise ValueError: If a feature value has an unsupported
            type, or a feature's type is inconsistent across tokens.
        """
        # Find the set of all attested labels.
        labels = set(label for (tok, label) in tokens)

        # Determine the types of all features.
        features = {}
        for tok, label in tokens:
            for (fname, fval) in tok.items():
                if issubclass(type(fval), bool):
                    ftype = '{True, False}'
                elif issubclass(type(fval), (int, float, long, bool)):
                    ftype = 'NUMERIC'
                elif issubclass(type(fval), basestring):
                    ftype = 'STRING'
                elif fval is None:
                    continue # can't tell the type.
                else:
                    # Bug fix: this previously interpolated 'ftype',
                    # which is unbound (or stale from a prior loop
                    # iteration) on this branch; report the value.
                    raise ValueError('Unsupported value type %r' % fval)

                if features.get(fname, ftype) != ftype:
                    raise ValueError('Inconsistent type for %s' % fname)
                features[fname] = ftype
        features = sorted(features.items())

        return ARFF_Formatter(labels, features)

    def header_section(self):
        """Return the ARFF header: comments, the @RELATION line, and
        one @ATTRIBUTE line per feature plus one for the label."""
        # Header comment.
        s = ('% Weka ARFF file\n' +
             '% Generated automatically by NLTK\n' +
             '%% %s\n\n' % time.ctime())

        # Relation name
        s += '@RELATION rel\n\n'

        # Input attribute specifications
        for fname, ftype in self._features:
            s += '@ATTRIBUTE %-30r %s\n' % (fname, ftype)

        # Label attribute specification
        s += '@ATTRIBUTE %-30r {%s}\n' % ('-label-', ','.join(self._labels))

        return s

    def data_section(self, tokens, labeled=None):
        """
        Return the ARFF @DATA section for the given tokens.

        @param labeled: Indicates whether the given tokens are labeled
            or not.  If C{None}, then the tokens will be assumed to be
            labeled if the first token is a tuple or list (i.e. looks
            like a C{(featureset, label)} pair).
        """
        # Check if the tokens are labeled or unlabeled.  If unlabeled,
        # then pair each token with the label None.
        if labeled is None:
            labeled = tokens and isinstance(tokens[0], (tuple, list))
        if not labeled:
            tokens = [(tok, None) for tok in tokens]

        # Data section: one comma-separated row per token, label last.
        s = '\n@DATA\n'
        for (tok, label) in tokens:
            for fname, ftype in self._features:
                s += '%s,' % self._fmt_arff_val(tok.get(fname))
            s += '%s\n' % self._fmt_arff_val(label)

        return s

    def _fmt_arff_val(self, fval):
        """Render one feature value (or label) for an ARFF data row;
        missing values (None) become '?'."""
        if fval is None:
            return '?'
        elif isinstance(fval, (bool, int, long)):
            return '%s' % fval
        else:
            # Floats and everything else (notably strings) use repr().
            # (Collapsed two branches that returned the identical
            # expression.)
            return '%r' % fval
if __name__ == '__main__':
    from nltk.classify.util import names_demo,binary_names_demo_features

    def make_classifier(featuresets):
        # Demo trainer: fit a C4.5 (J48) model on the names corpus.
        return WekaClassifier.train('/tmp/name.model', featuresets,
                                    'C4.5')

    classifier = names_demo(make_classifier,binary_names_demo_features)
319 classifier = names_demo(make_classifier,binary_names_demo_features) 320