Package nltk :: Package tag :: Module crf
[hide private]
[frames] | [no frames]

Source Code for Module nltk.tag.crf

  1  # Natural Language Toolkit: Conditional Random Fields 
  2  # 
  3  # Copyright (C) 2001-2008 NLTK Project 
  4  # Author: Edward Loper <[email protected]> 
  5  # URL: <http://nltk.org> 
  6  # For license information, see LICENSE.TXT 
  7  # 
  8  # $Id: hmm.py 5994 2008-06-02 12:07:07Z stevenbird $ 
  9   
 10   
 11  """ 
 12  An interface to U{Mallet <http://mallet.cs.umass.edu/>}'s Linear Chain 
 13  Conditional Random Field (LC-CRF) implementation. 
 14   
 15  A user-supplied I{feature detector function} is used to convert each 
 16  token to a featureset.  Each feature/value pair is then encoded as a 
 17  single binary feature for Mallet. 
 18  """ 
 19   
from tempfile import *
import os
import pickle
import re
import subprocess
import sys
import textwrap
import time
import zipfile

from nltk.classify.maxent import *
from nltk.classify.mallet import call_mallet
from nltk.etree import ElementTree

from api import *
 34   
class MalletCRF(FeaturesetTaggerI):
    """
    A conditional random field tagger, which is trained and run by
    making external calls to Mallet.  Tokens are converted to
    featuresets using a feature detector function::

        feature_detector(tokens, index) -> featureset

    These featuresets are then encoded into feature vectors by
    converting each feature (name, value) pair to a unique binary
    feature.

    Each C{MalletCRF} object is backed by a X{crf model file}.  This
    model file is actually a zip file, and it contains one file for
    the serialized model (C{crf-model.ser}) and one file for
    information about the structure of the CRF (C{crf-info.xml}).
    """

    def __init__(self, filename, feature_detector=None):
        """
        Create a new C{MalletCRF}.

        @param filename: The filename of the model file that backs
            this CRF.
        @param feature_detector: The feature detector function that is
            used to convert tokens to featuresets.  This parameter
            only needs to be given if the model file does not contain
            a pickled pointer to the feature detector (e.g., if the
            feature detector was a lambda function).
        @raise ValueError: If the supplied feature detector conflicts
            with the one recorded in the model file, or if neither is
            available.
        """
        # Read the CRFInfo from the model file.
        zf = zipfile.ZipFile(filename)
        crf_info = CRFInfo.fromstring(zf.read('crf-info.xml'))
        zf.close()

        # A L{CRFInfo} object describing this CRF.
        self.crf_info = crf_info

        # Ensure that our crf_info object has a feature detector.
        if crf_info.feature_detector is not None:
            # The model file carried a pickled detector; if the caller
            # supplied one as well, the two must agree.
            if (feature_detector is not None and
                self.crf_info.feature_detector != feature_detector):
                raise ValueError('Feature detector mismatch: %r vs %r' %
                                 (feature_detector, self.crf_info.feature_detector))
        elif feature_detector is None:
            raise ValueError('Feature detector not found; supply it manually.')
        elif feature_detector.__name__ != crf_info.feature_detector_name:
            raise ValueError('Feature detector name mismatch: %r vs %r' %
                             (feature_detector.__name__,
                              crf_info.feature_detector_name))
        else:
            self.crf_info.feature_detector = feature_detector

    #/////////////////////////////////////////////////////////////////
    # Convenience accessors (info also available via self.crf_info)
    #/////////////////////////////////////////////////////////////////

    def _get_filename(self):
        # Accessor used by the C{filename} property below.
        return self.crf_info.model_filename
    filename = property(_get_filename , doc="""
        The filename of the crf model file that backs this
        C{MalletCRF}.  The crf model file is actually a zip file, and
        it contains one file for the serialized model
        (C{crf-model.ser}) and one file for information about the
        structure of the CRF (C{crf-info.xml}).""")
101 - def _get_feature_detector(self):
102 return self.crf_info.model_feature_detector
103 feature_detector = property(_get_feature_detector , doc=""" 104 The feature detector function that is used to convert tokens 105 to featuresets. This function has the signature:: 106 107 feature_detector(tokens, index) -> featureset""") 108 109 #///////////////////////////////////////////////////////////////// 110 # Tagging 111 #///////////////////////////////////////////////////////////////// 112 113 #: The name of the java script used to run MalletCRFs. 114 _RUN_CRF = "org.nltk.mallet.RunCRF" 115
    def batch_tag(self, sentences):
        """
        Tag each sentence in C{sentences}, and return the results as
        a list of tagged sentences (lists of C{(token, tag)} pairs).
        The sentences are written to a temporary file, tagged by an
        external Mallet process, and the labels are then read back
        from that process's output.
        """
        # Write the test corpus to a temporary file
        (fd, test_file) = mkstemp('.txt', 'test')
        self.write_test_corpus(sentences, os.fdopen(fd, 'w'))

        try:
            # Run mallet on the test file.
            stdout, stderr = call_mallet([self._RUN_CRF,
                '--model-file', os.path.abspath(self.crf_info.model_filename),
                '--test-file', test_file], stdout='pipe')

            # Decode the output
            labels = self.parse_mallet_output(stdout)

            # strip __start__ and __end__ labels, which were added by
            # write_test_corpus when the CRF uses special start/end states.
            if self.crf_info.add_start_state and self.crf_info.add_end_state:
                labels = [labs[1:-1] for labs in labels]
            elif self.crf_info.add_start_state:
                labels = [labs[1:] for labs in labels]
            elif self.crf_info.add_end_state:
                labels = [labs[:-1] for labs in labels]

            # Combine the labels and the original sentences.
            return [zip(sent, label) for (sent,label) in
                    zip(sentences, labels)]

        finally:
            # Always delete the temporary test file.
            os.remove(test_file)

    #/////////////////////////////////////////////////////////////////
    # Training
    #/////////////////////////////////////////////////////////////////

    #: The name of the java script used to train MalletCRFs.
    _TRAIN_CRF = "org.nltk.mallet.TrainCRF"

    @classmethod
    def train(cls, feature_detector, corpus, filename=None,
              weight_groups=None, gaussian_variance=1, default_label='O',
              transduction_type='VITERBI', max_iterations=500,
              add_start_state=True, add_end_state=True, trace=1):
        """
        Train a new linear chain CRF tagger based on the given corpus
        of training sequences.  This tagger will be backed by a I{crf
        model file}, containing both a serialized Mallet model and
        information about the CRF's structure.  This crf model file
        will I{not} be automatically deleted -- if you wish to delete
        it, you must delete it manually.  The filename of the model
        file for a MalletCRF C{crf} is available as C{crf.filename}.

        @param feature_detector: The feature detector function that is
            used to convert tokens to featuresets::

                feature_detector(tokens, index) -> featureset

        @type corpus: C{list} of C{tuple}
        @param corpus: Training data, represented as a list of
            sentences, where each sentence is a list of (token, tag)
            tuples.

        @type filename: C{str}
        @param filename: The filename that should be used for the crf
            model file that backs the new C{MalletCRF}.  If no
            filename is given, then a new filename will be chosen
            automatically.

        @type weight_groups: C{list} of L{CRFInfo.WeightGroup}
        @param weight_groups: Specifies how input-features should
            be mapped to joint-features.  See L{CRFInfo.WeightGroup}
            for more information.

        @type gaussian_variance: C{float}
        @param gaussian_variance: The gaussian variance of the prior
            that should be used to train the new CRF.

        @type default_label: C{str}
        @param default_label: The "label for initial context and
            uninteresting tokens" (from Mallet's SimpleTagger.java.)
            It's unclear whether this currently has any effect.

        @type transduction_type: C{str}
        @param transduction_type: The type of transduction used by
            the CRF.  Can be VITERBI, VITERBI_FBEAM, VITERBI_BBEAM,
            VITERBI_FBBEAM, or VITERBI_FBEAMKL.

        @type max_iterations: C{int}
        @param max_iterations: The maximum number of iterations that
            should be used for training the CRF.

        @type add_start_state: C{bool}
        @param add_start_state: If true, then NLTK will add a special
            start state, named C{'__start__'}.  The initial cost for
            the start state will be set to 0; and the initial cost for
            all other states will be set to +inf.

        @type add_end_state: C{bool}
        @param add_end_state: If true, then NLTK will add a special
            end state, named C{'__end__'}.  The final cost for the end
            state will be set to 0; and the final cost for all other
            states will be set to +inf.

        @type trace: C{int}
        @param trace: Controls the verbosity of trace output generated
            while training the CRF.  Higher numbers generate more verbose
            output.
        """
        t0 = time.time() # Record starting time.

        # If they did not supply a model filename, then choose one.
        if filename is None:
            (fd, filename) = mkstemp('.crf', 'model')
            os.fdopen(fd).close()

        # Ensure that the filename ends with '.crf'
        if not filename.endswith('.crf'):
            filename += '.crf'

        if trace >= 1:
            print '[MalletCRF] Training a new CRF: %s' % filename

        # Create crf-info object describing the new CRF.
        crf_info = MalletCRF._build_crf_info(
            corpus, gaussian_variance, default_label, max_iterations,
            transduction_type, weight_groups, add_start_state,
            add_end_state, filename, feature_detector)

        # Create a zipfile, and write crf-info to it.
        if trace >= 2:
            print '[MalletCRF] Adding crf-info.xml to %s' % filename
        zf = zipfile.ZipFile(filename, mode='w')
        zf.writestr('crf-info.xml', crf_info.toxml()+'\n')
        zf.close()

        # Create the CRF object.
        crf = MalletCRF(filename, feature_detector)

        # Write the Training corpus to a temporary file.
        if trace >= 2:
            print '[MalletCRF] Writing training corpus...'
        (fd, train_file) = mkstemp('.txt', 'train')
        crf.write_training_corpus(corpus, os.fdopen(fd, 'w'))

        try:
            if trace >= 1:
                print '[MalletCRF] Calling mallet to train CRF...'
            cmd = [MalletCRF._TRAIN_CRF,
                   '--model-file', os.path.abspath(filename),
                   '--train-file', train_file]
            if trace > 3:
                # Let mallet's output go straight to the terminal.
                call_mallet(cmd)
            else:
                # Capture mallet's (very verbose) output and filter it.
                p = call_mallet(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                blocking=False)
                MalletCRF._filter_training_output(p, trace)
        finally:
            # Delete the temp file containing the training corpus.
            os.remove(train_file)

        if trace >= 1:
            print '[MalletCRF] Training complete.'
            print '[MalletCRF] Model stored in: %s' % filename
        if trace >= 2:
            dt = time.time()-t0
            print '[MalletCRF] Total training time: %d seconds' % dt

        # Return the completed CRF.
        return crf
280 281 @staticmethod
282 - def _build_crf_info(corpus, gaussian_variance, default_label, 283 max_iterations, transduction_type, weight_groups, 284 add_start_state, add_end_state, 285 model_filename, feature_detector):
286 """ 287 Construct a C{CRFInfo} object describing a CRF with a given 288 set of configuration parameters, and based on the contents of 289 a given corpus. 290 """ 291 state_info_list = [] 292 293 labels = set() 294 if add_start_state: 295 labels.add('__start__') 296 if add_end_state: 297 labels.add('__end__') 298 transitions = set() # not necessary to find this? 299 for sent in corpus: 300 prevtag = default_label 301 for (tok,tag) in sent: 302 labels.add(tag) 303 transitions.add( (prevtag, tag) ) 304 prevtag = tag 305 if add_start_state: 306 transitions.add( ('__start__', sent[0][1]) ) 307 if add_end_state: 308 transitions.add( (sent[-1][1], '__end__') ) 309 labels = sorted(labels) 310 311 # 0th order default: 312 if weight_groups is None: 313 weight_groups = [CRFInfo.WeightGroup(name=l, src='.*', 314 dst=re.escape(l)) 315 for l in labels] 316 317 # Check that weight group names are unique 318 if len(weight_groups) != len(set(wg.name for wg in weight_groups)): 319 raise ValueError("Weight group names must be unique") 320 321 # Construct a list of state descriptions. Currently, we make 322 # these states fully-connected, with one parameter per 323 # transition. 324 for src in labels: 325 if add_start_state: 326 if src == '__start__': 327 initial_cost = 0 328 else: 329 initial_cost = '+inf' 330 if add_end_state: 331 if src == '__end__': 332 final_cost = 0 333 else: 334 final_cost = '+inf' 335 state_info = CRFInfo.State(src, initial_cost, final_cost, []) 336 for dst in labels: 337 state_weight_groups = [wg.name for wg in weight_groups 338 if wg.match(src, dst)] 339 state_info.transitions.append( 340 CRFInfo.Transition(dst, dst, state_weight_groups)) 341 state_info_list.append(state_info) 342 343 return CRFInfo(state_info_list, gaussian_variance, 344 default_label, max_iterations, 345 transduction_type, weight_groups, 346 add_start_state, add_end_state, 347 model_filename, feature_detector)

    #: A table used to filter the output that mallet generates during
    #: training.  By default, mallet generates very verbose output.
    #: This table is used to select which lines of output are actually
    #: worth displaying to the user, based on the level of the C{trace}
    #: parameter.  Each entry of this table is a tuple
    #: C{(min_trace_level, regexp)}.  A line will be displayed only if
    #: C{trace>=min_trace_level} and the line matches C{regexp} for at
    #: least one table entry.
    _FILTER_TRAINING_OUTPUT = [
        (1, r'DEBUG:.*'),
        (1, r'Number of weights.*'),
        (1, r'CRF about to train.*'),
        (1, r'CRF finished.*'),
        (1, r'CRF training has converged.*'),
        (2, r'CRF weights.*'),
        (2, r'getValue\(\) \(loglikelihood\) .*'),
        ]

    @staticmethod
    def _filter_training_output(p, trace):
        """
        Filter the (very verbose) output that is generated by mallet,
        and only display the interesting lines.  The lines that are
        selected for display are determined by
        L{_FILTER_TRAINING_OUTPUT}.

        @param p: The mallet subprocess object (as returned by
            C{call_mallet(..., blocking=False)}).
        @param trace: The verbosity level; see L{train()}.
        @raise OSError: If the mallet command exits with a nonzero
            return code.
        """
        out = []
        # NOTE(review): this polls the subprocess and drains whatever
        # output is available via the inner readline loop; it keeps
        # the last lines in `out` so they can be shown on failure.
        while p.poll() is None:
            while True:
                line = p.stdout.readline()
                if not line: break
                out.append(line)
                # Display the line only if one of the filter regexps
                # matches at a level permitted by `trace`.
                for (t, regexp) in MalletCRF._FILTER_TRAINING_OUTPUT:
                    if t <= trace and re.match(regexp, line):
                        indent = ' '*t
                        print '[MalletCRF] %s%s' % (indent, line.rstrip())
                        break
        if p.returncode != 0:
            print "\nError encountered! Mallet's most recent output:"
            print ''.join(out[-100:])
            raise OSError('Mallet command failed')
390 391 392 #///////////////////////////////////////////////////////////////// 393 # Communication w/ mallet 394 #///////////////////////////////////////////////////////////////// 395
396 - def write_training_corpus(self, corpus, stream, close_stream=True):
397 """ 398 Write a given training corpus to a given stream, in a format that 399 can be read by the java script C{org.nltk.mallet.TrainCRF}. 400 """ 401 feature_detector = self.crf_info.feature_detector 402 for sentence in corpus: 403 if self.crf_info.add_start_state: 404 stream.write('__start__ __start__\n') 405 unlabeled_sent = [tok for (tok,tag) in sentence] 406 for index in range(len(unlabeled_sent)): 407 featureset = feature_detector(unlabeled_sent, index) 408 for (fname, fval) in featureset.items(): 409 stream.write(self._format_feature(fname, fval)+" ") 410 stream.write(sentence[index][1]+'\n') 411 if self.crf_info.add_end_state: 412 stream.write('__end__ __end__\n') 413 stream.write('\n') 414 if close_stream: stream.close()
415
416 - def write_test_corpus(self, corpus, stream, close_stream=True):
417 """ 418 Write a given test corpus to a given stream, in a format that 419 can be read by the java script C{org.nltk.mallet.TestCRF}. 420 """ 421 feature_detector = self.crf_info.feature_detector 422 for sentence in corpus: 423 if self.crf_info.add_start_state: 424 stream.write('__start__ __start__\n') 425 for index in range(len(sentence)): 426 featureset = feature_detector(sentence, index) 427 for (fname, fval) in featureset.items(): 428 stream.write(self._format_feature(fname, fval)+" ") 429 stream.write('\n') 430 if self.crf_info.add_end_state: 431 stream.write('__end__ __end__\n') 432 stream.write('\n') 433 if close_stream: stream.close()
434
435 - def parse_mallet_output(self, s):
436 """ 437 Parse the output that is generated by the java script 438 C{org.nltk.mallet.TestCRF}, and convert it to a labeled 439 corpus. 440 """ 441 if re.match(r'\s*<<start>>', s): 442 assert 0, 'its a lattice' 443 corpus = [[]] 444 for line in s.split('\n'): 445 line = line.strip() 446 # Label with augmentations? 447 if line: 448 corpus[-1].append(line.strip()) 449 # Start of new instance? 450 elif corpus[-1] != []: 451 corpus.append([]) 452 if corpus[-1] == []: corpus.pop() 453 return corpus
454 455 _ESCAPE_RE = re.compile('[^a-zA-Z0-9]') 456 @staticmethod
457 - def _escape_sub(m):
458 return '%' + ('%02x' % ord(m.group()))

    @staticmethod
    def _format_feature(fname, fval):
        """
        Return a string name for a given feature (name, value) pair,
        appropriate for consumption by mallet.  We escape every
        character in fname or fval that's not a letter or a number,
        just to be conservative.
        """
        fname = MalletCRF._ESCAPE_RE.sub(MalletCRF._escape_sub, fname)
        # String values are wrapped in quotes so that e.g. the string
        # '3' is distinguished from the integer 3; other values are
        # rendered via repr() before escaping.
        if isinstance(fval, basestring):
            fval = "'%s'" % MalletCRF._ESCAPE_RE.sub(
                MalletCRF._escape_sub, fval)
        else:
            fval = MalletCRF._ESCAPE_RE.sub(MalletCRF._escape_sub, '%r'%fval)
        return fname+'='+fval
475 476 #///////////////////////////////////////////////////////////////// 477 # String Representation 478 #///////////////////////////////////////////////////////////////// 479
480 - def __repr__(self):
481 return 'MalletCRF(%r)' % self.crf_info.model_filename
482 483 ########################################################################### 484 ## Serializable CRF Information Object 485 ########################################################################### 486
class CRFInfo(object):
    """
    An object used to record configuration information about a
    MalletCRF object.  This configuration information can be
    serialized to an XML file, which can then be read by NLTK's custom
    interface to Mallet's CRF.

    CRFInfo objects are typically created by the L{MalletCRF.train()}
    method.

    Advanced users may wish to directly create custom
    C{CRFInfo.WeightGroup} objects and pass them to the
    L{MalletCRF.train()} function.  See L{CRFInfo.WeightGroup} for
    more information.
    """
    def __init__(self, states, gaussian_variance, default_label,
                 max_iterations, transduction_type, weight_groups,
                 add_start_state, add_end_state, model_filename,
                 feature_detector):
        """
        Record the given CRF configuration.  See L{MalletCRF.train()}
        for a description of most parameters; C{states} is a list of
        L{CRFInfo.State} objects and C{feature_detector} may be either
        a callable or a bare detector name.
        """
        self.gaussian_variance = float(gaussian_variance)
        self.default_label = default_label
        self.states = states
        self.max_iterations = max_iterations
        self.transduction_type = transduction_type
        self.weight_groups = weight_groups
        self.add_start_state = add_start_state
        self.add_end_state = add_end_state
        self.model_filename = model_filename
        # The detector may be given by name only (e.g. when the real
        # function could not be pickled into the model file).
        if isinstance(feature_detector, basestring):
            self.feature_detector_name = feature_detector
            self.feature_detector = None
        else:
            self.feature_detector_name = feature_detector.__name__
            self.feature_detector = feature_detector
521 522 _XML_TEMPLATE = ( 523 '<crf>\n' 524 ' <modelFile>%(model_filename)s</modelFile>\n' 525 ' <gaussianVariance>%(gaussian_variance)d</gaussianVariance>\n' 526 ' <defaultLabel>%(default_label)s</defaultLabel>\n' 527 ' <maxIterations>%(max_iterations)s</maxIterations>\n' 528 ' <transductionType>%(transduction_type)s</transductionType>\n' 529 ' <featureDetector name="%(feature_detector_name)s">\n' 530 ' %(feature_detector)s\n' 531 ' </featureDetector>\n' 532 ' <addStartState>%(add_start_state)s</addStartState>\n' 533 ' <addEndState>%(add_end_state)s</addEndState>\n' 534 ' <states>\n' 535 '%(states)s\n' 536 ' </states>\n' 537 ' <weightGroups>\n' 538 '%(w_groups)s\n' 539 ' </weightGroups>\n' 540 '</crf>\n') 541
542 - def toxml(self):
543 info = self.__dict__.copy() 544 info['states'] = '\n'.join(state.toxml() for state in self.states) 545 info['w_groups'] = '\n'.join(wg.toxml() for wg in self.weight_groups) 546 info['feature_detector_name'] = (info['feature_detector_name'] 547 .replace('&', '&amp;') 548 .replace('<', '&lt;')) 549 try: 550 fd = pickle.dumps(self.feature_detector) 551 fd = fd.replace('&', '&amp;').replace('<', '&lt;') 552 fd = fd.replace('\n', '&#10;') # put pickle data all on 1 line. 553 info['feature_detector'] = '<pickle>%s</pickle>' % fd 554 except pickle.PicklingError: 555 info['feature_detector'] = '' 556 return self._XML_TEMPLATE % info
557 558 @staticmethod
559 - def fromstring(s):
561 562 @staticmethod
563 - def _read(etree):
564 states = [CRFInfo.State._read(et) for et in 565 etree.findall('states/state')] 566 weight_groups = [CRFInfo.WeightGroup._read(et) for et in 567 etree.findall('weightGroups/weightGroup')] 568 fd = etree.find('featureDetector') 569 feature_detector = fd.get('name') 570 if fd.find('pickle') is not None: 571 try: feature_detector = pickle.loads(fd.find('pickle').text) 572 except pickle.PicklingError, e: pass # unable to unpickle it. 573 574 return CRFInfo(states, 575 float(etree.find('gaussianVariance').text), 576 etree.find('defaultLabel').text, 577 int(etree.find('maxIterations').text), 578 etree.find('transductionType').text, 579 weight_groups, 580 bool(etree.find('addStartState').text), 581 bool(etree.find('addEndState').text), 582 etree.find('modelFile').text, 583 feature_detector)
584
585 - def write(self, filename):
586 out = open(filename, 'w') 587 out.write(self.toxml()) 588 out.write('\n') 589 out.close()
590
    class State(object):
        """
        A description of a single CRF state.
        """
        def __init__(self, name, initial_cost, final_cost, transitions):
            """
            @param name: The name of this state.
            @param initial_cost: The cost of starting a sequence in
                this state; a number, or the string C{'+inf'}.
            @param final_cost: The cost of ending a sequence in this
                state; a number, or the string C{'+inf'}.
            @param transitions: A list of L{CRFInfo.Transition}s
                leaving this state.
            """
            # '+inf' is kept as a string; any other cost is coerced to
            # a float.
            if initial_cost != '+inf': initial_cost = float(initial_cost)
            if final_cost != '+inf': final_cost = float(final_cost)
            self.name = name
            self.initial_cost = initial_cost
            self.final_cost = final_cost
            self.transitions = transitions

        #: Template used by L{toxml()} to serialize this state.
        _XML_TEMPLATE = (
            ' <state name="%(name)s" initialCost="%(initial_cost)s" '
            'finalCost="%(final_cost)s">\n'
            ' <transitions>\n'
            '%(transitions)s\n'
            ' </transitions>\n'
            ' </state>\n')
        def toxml(self):
            """
            Return an XML string describing this state and its
            outgoing transitions.
            """
            info = self.__dict__.copy()
            info['transitions'] = '\n'.join(transition.toxml()
                                            for transition in self.transitions)
            return self._XML_TEMPLATE % info
615 616 @staticmethod
617 - def _read(etree):
618 transitions = [CRFInfo.Transition._read(et) 619 for et in etree.findall('transitions/transition')] 620 return CRFInfo.State(etree.get('name'), 621 etree.get('initialCost'), 622 etree.get('finalCost'), 623 transitions)
624
    class Transition(object):
        """
        A description of a single CRF transition.
        """
        def __init__(self, destination, label, weightgroups):
            """
            @param destination: The name of the state that this transition
                connects to.
            @param label: The tag that is generated when traversing this
                transition.
            @param weightgroups: A list of L{WeightGroup} names, indicating
                which weight groups should be used to calculate the cost
                of traversing this transition.
            """
            self.destination = destination
            self.label = label
            self.weightgroups = weightgroups
642 643 _XML_TEMPLATE = (' <transition label="%(label)s" ' 644 'destination="%(destination)s" ' 645 'weightGroups="%(w_groups)s"/>')
646 - def toxml(self):
647 info = self.__dict__ 648 info['w_groups'] = ' '.join(wg for wg in self.weightgroups) 649 return self._XML_TEMPLATE % info
650 651 @staticmethod
652 - def _read(etree):
653 return CRFInfo.Transition(etree.get('destination'), 654 etree.get('label'), 655 etree.get('weightGroups').split())
656
    class WeightGroup(object):
        """
        A configuration object used by C{MalletCRF} to specify how
        input-features (which are a function of only the input) should be
        mapped to joint-features (which are a function of both the input
        and the output tags).

        Each weight group specifies that a given set of input features
        should be paired with all transitions from a given set of source
        tags to a given set of destination tags.
        """
        def __init__(self, name, src, dst, features='.*'):
            """
            @param name: A unique name for this weight group.
            @param src: The set of source tags that should be used for
                this weight group, specified as either a list of state
                names or a regular expression.
            @param dst: The set of destination tags that should be used
                for this weight group, specified as either a list of state
                names or a regular expression.
            @param features: The set of input features that should be used
                for this weight group, specified as either a list of
                feature names or a regular expression.  WARNING: currently,
                this regexp is passed straight to java -- i.e., it must
                be a java-style regexp!
            @raise ValueError: If C{name} contains whitespace or a
                double-quote character.
            """
            # Names are embedded in XML attributes and space-separated
            # lists, so whitespace and quotes are forbidden.
            if re.search('\s', name):
                raise ValueError('weight group name may not '
                                 'contain whitespace.')
            if re.search('"', name):
                raise ValueError('weight group name may not contain \'"\'.')
            self.name = name
            self.src = src
            self.dst = dst
            self.features = features
            # Per-instance caches used by match() (state name -> bool).
            self._src_match_cache = {}
            self._dst_match_cache = {}

        #: Template used by L{toxml()} to serialize this weight group.
        _XML_TEMPLATE = (' <weightGroup name="%(name)s" src="%(src)s" '
                         'dst="%(dst)s" features="%(features)s" />')
        def toxml(self):
            """
            Return an XML string describing this weight group.
            """
            return self._XML_TEMPLATE % self.__dict__
699 700 @staticmethod
701 - def _read(etree):
702 return CRFInfo.WeightGroup(etree.get('name'), 703 etree.get('src'), 704 etree.get('dst'), 705 etree.get('features'))
706 707 # [xx] feature name????
708 - def match(self, src, dst):
709 # Check if the source matches 710 src_match = self._src_match_cache.get(src) 711 if src_match is None: 712 if isinstance(self.src, basestring): 713 src_match = bool(re.match(self.src+'\Z', src)) 714 else: 715 src_match = src in self.src 716 self._src_match_cache[src] = src_match 717 718 # Check if the dest matches 719 dst_match = self._dst_match_cache.get(dst) 720 if dst_match is None: 721 if isinstance(self.dst, basestring): 722 dst_match = bool(re.match(self.dst+'\Z', dst)) 723 else: 724 dst_match = dst in self.dst 725 self._dst_match_cache[dst] = dst_match 726 727 # Return true if both matched. 728 return src_match and dst_match
729 730 ########################################################################### 731 ## Demonstration code 732 ########################################################################### 733
734 -def demo(train_size=100, test_size=100, 735 java_home='/usr/local/jdk1.5.0/', 736 mallet_home='/usr/local/mallet-0.4'):
737 from nltk.corpus import brown 738 import textwrap 739 740 # Define a very simple feature detector 741 def fd(sentence, index): 742 word = sentence[index] 743 return dict(word=word, suffix=word[-2:], len=len(word))
744 745 # Let nltk know where java & mallet are. 746 nltk.internals.config_java(java_home) 747 nltk.classify.mallet.config_mallet(mallet_home) 748 749 # Get the training & test corpus. We simplify the tagset a little: 750 # just the first 2 chars. 751 def strip(corpus): return [[(w, t[:2]) for (w,t) in sent] 752 for sent in corpus] 753 brown_train = strip(brown.tagged_sents(categories='a')[:train_size]) 754 brown_test = strip(brown.tagged_sents(categories='b')[:test_size]) 755 756 crf = MalletCRF.train(fd, brown_train, #'/tmp/crf-model', 757 transduction_type='VITERBI') 758 sample_output = crf.tag([w for (w,t) in brown_test[5]]) 759 acc = nltk.tag.accuracy(crf, brown_test) 760 print '\nAccuracy: %.1f%%' % (acc*100) 761 print 'Sample output:' 762 print textwrap.fill(' '.join('%s/%s' % w for w in sample_output), 763 initial_indent=' ', subsequent_indent=' ')+'\n' 764 765 # Clean up 766 print 'Clean-up: deleting', crf.filename 767 os.remove(crf.filename) 768 769 if __name__ == '__main__': 770 demo(train_size=100) 771