1
2
3
4
5
6
7
8
9
10
11 """
12 An interface to U{Mallet <http://mallet.cs.umass.edu/>}'s Linear Chain
13 Conditional Random Field (LC-CRF) implementation.
14
15 A user-supplied I{feature detector function} is used to convert each
16 token to a featureset. Each feature/value pair is then encoded as a
17 single binary feature for Mallet.
18 """
19
import os
import pickle
import re
import subprocess
import sys
import textwrap
import time
import zipfile
from tempfile import *

from nltk.classify.maxent import *
from nltk.classify.mallet import call_mallet
from nltk.etree import ElementTree

from api import *
34
36 """
37 A conditional random field tagger, which is trained and run by
38 making external calls to Mallet. Tokens are converted to
39 featuresets using a feature detector function::
40
41 feature_detector(tokens, index) -> featureset
42
43 These featuresets are then encoded into feature vectors by
44 converting each feature (name, value) pair to a unique binary
45 feature.
46
Each C{MalletCRF} object is backed by a X{crf model file}. This
48 model file is actually a zip file, and it contains one file for
49 the serialized model (C{crf-model.ser}) and one file for
50 information about the structure of the CRF (C{crf-info.xml}).
51 """
52
53 - def __init__(self, filename, feature_detector=None):
54 """
55 Create a new C{MalletCRF}.
56
57 @param filename: The filename of the model file that backs
58 this CRF.
59 @param feature_detector: The feature detector function that is
60 used to convert tokens to featuresets. This parameter
61 only needs to be given if the model file does not contain
62 a pickled pointer to the feature detector (e.g., if the
63 feature detector was a lambda function).
64 """
65
66 zf = zipfile.ZipFile(filename)
67 crf_info = CRFInfo.fromstring(zf.read('crf-info.xml'))
68 zf.close()
69
70 self.crf_info = crf_info
71 """A L{CRFInfo} object describing this CRF."""
72
73
74 if crf_info.feature_detector is not None:
75 if (feature_detector is not None and
76 self.crf_info.feature_detector != feature_detector):
77 raise ValueError('Feature detector mismatch: %r vs %r' %
78 (feature_detector, self.crf_info.feature_detector))
79 elif feature_detector is None:
80 raise ValueError('Feature detector not found; supply it manually.')
81 elif feature_detector.__name__ != crf_info.feature_detector_name:
82 raise ValueError('Feature detector name mismatch: %r vs %r' %
83 (feature_detector.__name__,
84 crf_info.feature_detector_name))
85 else:
86 self.crf_info.feature_detector = feature_detector
87
88
89
90
91
93 return self.crf_info.model_filename
94 filename = property(_get_filename , doc="""
95 The filename of the crf model file that backs this
96 C{MalletCRF}. The crf model file is actually a zip file, and
97 it contains one file for the serialized model
98 (C{crf-model.ser}) and one file for information about the
99 structure of the CRF (C{crf-info.xml}).""")
100
102 return self.crf_info.model_feature_detector
103 feature_detector = property(_get_feature_detector , doc="""
104 The feature detector function that is used to convert tokens
105 to featuresets. This function has the signature::
106
107 feature_detector(tokens, index) -> featureset""")
108
109
110
111
112
113
114 _RUN_CRF = "org.nltk.mallet.RunCRF"
115
117
118 (fd, test_file) = mkstemp('.txt', 'test')
119 self.write_test_corpus(sentences, os.fdopen(fd, 'w'))
120
121 try:
122
123 stdout, stderr = call_mallet([self._RUN_CRF,
124 '--model-file', os.path.abspath(self.crf_info.model_filename),
125 '--test-file', test_file], stdout='pipe')
126
127
128 labels = self.parse_mallet_output(stdout)
129
130
131 if self.crf_info.add_start_state and self.crf_info.add_end_state:
132 labels = [labs[1:-1] for labs in labels]
133 elif self.crf_info.add_start_state:
134 labels = [labs[1:] for labs in labels]
135 elif self.crf_info.add_end_state:
136 labels = [labs[:-1] for labs in labels]
137
138
139 return [zip(sent, label) for (sent,label) in
140 zip(sentences, labels)]
141
142 finally:
143 os.remove(test_file)
144
145
146
147
148
149
150 _TRAIN_CRF = "org.nltk.mallet.TrainCRF"
151
152 @classmethod
153 - def train(cls, feature_detector, corpus, filename=None,
154 weight_groups=None, gaussian_variance=1, default_label='O',
155 transduction_type='VITERBI', max_iterations=500,
156 add_start_state=True, add_end_state=True, trace=1):
157 """
158 Train a new linear chain CRF tagger based on the given corpus
159 of training sequences. This tagger will be backed by a I{crf
160 model file}, containing both a serialized Mallet model and
161 information about the CRF's structure. This crf model file
162 will I{not} be automatically deleted -- if you wish to delete
163 it, you must delete it manually. The filename of the model
164 file for a MalletCRF C{crf} is available as C{crf.filename}.
165
166
167 @type corpus: C{list} of C{tuple}
168 @param corpus: Training data, represented as a list of
169 sentences, where each sentence is a list of (token, tag)
170 tuples.
171
172 @type filename: C{str}
173 @param filename: The filename that should be used for the crf
174 model file that backs the new C{MalletCRF}. If no
175 filename is given, then a new filename will be chosen
176 automatically.
177
178 @type weight_groups: C{list} of L{CRFInfo.WeightGroup}
179 @param weight_groups: Specifies how input-features should
180 be mapped to joint-features. See L{CRFInfo.WeightGroup}
181 for more information.
182
183 @type gaussian_variance: C{float}
184 @param gaussian_variance: The gaussian variance of the prior
185 that should be used to train the new CRF.
186
187 @type default_label: C{str}
188 @param default_label: The "label for initial context and
189 uninteresting tokens" (from Mallet's SimpleTagger.java.)
190 It's unclear whether this currently has any effect.
191
192 @type transduction_type: C{str}
193 @param transduction_type: The type of transduction used by
194 the CRF. Can be VITERBI, VITERBI_FBEAM, VITERBI_BBEAM,
195 VITERBI_FBBEAM, or VITERBI_FBEAMKL.
196
197 @type max_iterations: C{int}
198 @param max_iterations: The maximum number of iterations that
199 should be used for training the CRF.
200
201 @type add_start_state: C{bool}
202 @param add_start_state: If true, then NLTK will add a special
203 start state, named C{'__start__'}. The initial cost for
204 the start state will be set to 0; and the initial cost for
205 all other states will be set to +inf.
206
207 @type add_end_state: C{bool}
208 @param add_end_state: If true, then NLTK will add a special
209 end state, named C{'__end__'}. The final cost for the end
210 state will be set to 0; and the final cost for all other
211 states will be set to +inf.
212
213 @type trace: C{int}
214 @param trace: Controls the verbosity of trace output generated
215 while training the CRF. Higher numbers generate more verbose
216 output.
217 """
218 t0 = time.time()
219
220
221 if filename is None:
222 (fd, filename) = mkstemp('.crf', 'model')
223 os.fdopen(fd).close()
224
225
226 if not filename.endswith('.crf'):
227 filename += '.crf'
228
229 if trace >= 1:
230 print '[MalletCRF] Training a new CRF: %s' % filename
231
232
233 crf_info = MalletCRF._build_crf_info(
234 corpus, gaussian_variance, default_label, max_iterations,
235 transduction_type, weight_groups, add_start_state,
236 add_end_state, filename, feature_detector)
237
238
239 if trace >= 2:
240 print '[MalletCRF] Adding crf-info.xml to %s' % filename
241 zf = zipfile.ZipFile(filename, mode='w')
242 zf.writestr('crf-info.xml', crf_info.toxml()+'\n')
243 zf.close()
244
245
246 crf = MalletCRF(filename, feature_detector)
247
248
249 if trace >= 2:
250 print '[MalletCRF] Writing training corpus...'
251 (fd, train_file) = mkstemp('.txt', 'train')
252 crf.write_training_corpus(corpus, os.fdopen(fd, 'w'))
253
254 try:
255 if trace >= 1:
256 print '[MalletCRF] Calling mallet to train CRF...'
257 cmd = [MalletCRF._TRAIN_CRF,
258 '--model-file', os.path.abspath(filename),
259 '--train-file', train_file]
260 if trace > 3:
261 call_mallet(cmd)
262 else:
263 p = call_mallet(cmd, stdout=subprocess.PIPE,
264 stderr=subprocess.STDOUT,
265 blocking=False)
266 MalletCRF._filter_training_output(p, trace)
267 finally:
268
269 os.remove(train_file)
270
271 if trace >= 1:
272 print '[MalletCRF] Training complete.'
273 print '[MalletCRF] Model stored in: %s' % filename
274 if trace >= 2:
275 dt = time.time()-t0
276 print '[MalletCRF] Total training time: %d seconds' % dt
277
278
279 return crf
280
281 @staticmethod
282 - def _build_crf_info(corpus, gaussian_variance, default_label,
283 max_iterations, transduction_type, weight_groups,
284 add_start_state, add_end_state,
285 model_filename, feature_detector):
286 """
287 Construct a C{CRFInfo} object describing a CRF with a given
288 set of configuration parameters, and based on the contents of
289 a given corpus.
290 """
291 state_info_list = []
292
293 labels = set()
294 if add_start_state:
295 labels.add('__start__')
296 if add_end_state:
297 labels.add('__end__')
298 transitions = set()
299 for sent in corpus:
300 prevtag = default_label
301 for (tok,tag) in sent:
302 labels.add(tag)
303 transitions.add( (prevtag, tag) )
304 prevtag = tag
305 if add_start_state:
306 transitions.add( ('__start__', sent[0][1]) )
307 if add_end_state:
308 transitions.add( (sent[-1][1], '__end__') )
309 labels = sorted(labels)
310
311
312 if weight_groups is None:
313 weight_groups = [CRFInfo.WeightGroup(name=l, src='.*',
314 dst=re.escape(l))
315 for l in labels]
316
317
318 if len(weight_groups) != len(set(wg.name for wg in weight_groups)):
319 raise ValueError("Weight group names must be unique")
320
321
322
323
324 for src in labels:
325 if add_start_state:
326 if src == '__start__':
327 initial_cost = 0
328 else:
329 initial_cost = '+inf'
330 if add_end_state:
331 if src == '__end__':
332 final_cost = 0
333 else:
334 final_cost = '+inf'
335 state_info = CRFInfo.State(src, initial_cost, final_cost, [])
336 for dst in labels:
337 state_weight_groups = [wg.name for wg in weight_groups
338 if wg.match(src, dst)]
339 state_info.transitions.append(
340 CRFInfo.Transition(dst, dst, state_weight_groups))
341 state_info_list.append(state_info)
342
343 return CRFInfo(state_info_list, gaussian_variance,
344 default_label, max_iterations,
345 transduction_type, weight_groups,
346 add_start_state, add_end_state,
347 model_filename, feature_detector)
348
349
350
351
352
353
354
355
356
357 _FILTER_TRAINING_OUTPUT = [
358 (1, r'DEBUG:.*'),
359 (1, r'Number of weights.*'),
360 (1, r'CRF about to train.*'),
361 (1, r'CRF finished.*'),
362 (1, r'CRF training has converged.*'),
363 (2, r'CRF weights.*'),
364 (2, r'getValue\(\) \(loglikelihood\) .*'),
365 ]
366
367 @staticmethod
369 """
370 Filter the (very verbose) output that is generated by mallet,
371 and only display the interesting lines. The lines that are
372 selected for display are determined by
373 L{_FILTER_TRAINING_OUTPUT}.
374 """
375 out = []
376 while p.poll() is None:
377 while True:
378 line = p.stdout.readline()
379 if not line: break
380 out.append(line)
381 for (t, regexp) in MalletCRF._FILTER_TRAINING_OUTPUT:
382 if t <= trace and re.match(regexp, line):
383 indent = ' '*t
384 print '[MalletCRF] %s%s' % (indent, line.rstrip())
385 break
386 if p.returncode != 0:
387 print "\nError encountered! Mallet's most recent output:"
388 print ''.join(out[-100:])
389 raise OSError('Mallet command failed')
390
391
392
393
394
395
397 """
398 Write a given training corpus to a given stream, in a format that
399 can be read by the java script C{org.nltk.mallet.TrainCRF}.
400 """
401 feature_detector = self.crf_info.feature_detector
402 for sentence in corpus:
403 if self.crf_info.add_start_state:
404 stream.write('__start__ __start__\n')
405 unlabeled_sent = [tok for (tok,tag) in sentence]
406 for index in range(len(unlabeled_sent)):
407 featureset = feature_detector(unlabeled_sent, index)
408 for (fname, fval) in featureset.items():
409 stream.write(self._format_feature(fname, fval)+" ")
410 stream.write(sentence[index][1]+'\n')
411 if self.crf_info.add_end_state:
412 stream.write('__end__ __end__\n')
413 stream.write('\n')
414 if close_stream: stream.close()
415
417 """
418 Write a given test corpus to a given stream, in a format that
419 can be read by the java script C{org.nltk.mallet.TestCRF}.
420 """
421 feature_detector = self.crf_info.feature_detector
422 for sentence in corpus:
423 if self.crf_info.add_start_state:
424 stream.write('__start__ __start__\n')
425 for index in range(len(sentence)):
426 featureset = feature_detector(sentence, index)
427 for (fname, fval) in featureset.items():
428 stream.write(self._format_feature(fname, fval)+" ")
429 stream.write('\n')
430 if self.crf_info.add_end_state:
431 stream.write('__end__ __end__\n')
432 stream.write('\n')
433 if close_stream: stream.close()
434
436 """
437 Parse the output that is generated by the java script
438 C{org.nltk.mallet.TestCRF}, and convert it to a labeled
439 corpus.
440 """
441 if re.match(r'\s*<<start>>', s):
442 assert 0, 'its a lattice'
443 corpus = [[]]
444 for line in s.split('\n'):
445 line = line.strip()
446
447 if line:
448 corpus[-1].append(line.strip())
449
450 elif corpus[-1] != []:
451 corpus.append([])
452 if corpus[-1] == []: corpus.pop()
453 return corpus
454
455 _ESCAPE_RE = re.compile('[^a-zA-Z0-9]')
456 @staticmethod
458 return '%' + ('%02x' % ord(m.group()))
459
460 @staticmethod
475
476
477
478
479
481 return 'MalletCRF(%r)' % self.crf_info.model_filename
482
483
484
485
486
488 """
489 An object used to record configuration information about a
490 MalletCRF object. This configuration information can be
491 serialized to an XML file, which can then be read by NLTK's custom
492 interface to Mallet's CRF.
493
494 CRFInfo objects are typically created by the L{MalletCRF.train()}
495 method.
496
497 Advanced users may wish to directly create custom
498 C{CRFInfo.WeightGroup} objects and pass them to the
499 L{MalletCRF.train()} function. See L{CRFInfo.WeightGroup} for
500 more information.
501 """
502 - def __init__(self, states, gaussian_variance, default_label,
503 max_iterations, transduction_type, weight_groups,
504 add_start_state, add_end_state, model_filename,
505 feature_detector):
506 self.gaussian_variance = float(gaussian_variance)
507 self.default_label = default_label
508 self.states = states
509 self.max_iterations = max_iterations
510 self.transduction_type = transduction_type
511 self.weight_groups = weight_groups
512 self.add_start_state = add_start_state
513 self.add_end_state = add_end_state
514 self.model_filename = model_filename
515 if isinstance(feature_detector, basestring):
516 self.feature_detector_name = feature_detector
517 self.feature_detector = None
518 else:
519 self.feature_detector_name = feature_detector.__name__
520 self.feature_detector = feature_detector
521
522 _XML_TEMPLATE = (
523 '<crf>\n'
524 ' <modelFile>%(model_filename)s</modelFile>\n'
525 ' <gaussianVariance>%(gaussian_variance)d</gaussianVariance>\n'
526 ' <defaultLabel>%(default_label)s</defaultLabel>\n'
527 ' <maxIterations>%(max_iterations)s</maxIterations>\n'
528 ' <transductionType>%(transduction_type)s</transductionType>\n'
529 ' <featureDetector name="%(feature_detector_name)s">\n'
530 ' %(feature_detector)s\n'
531 ' </featureDetector>\n'
532 ' <addStartState>%(add_start_state)s</addStartState>\n'
533 ' <addEndState>%(add_end_state)s</addEndState>\n'
534 ' <states>\n'
535 '%(states)s\n'
536 ' </states>\n'
537 ' <weightGroups>\n'
538 '%(w_groups)s\n'
539 ' </weightGroups>\n'
540 '</crf>\n')
541
543 info = self.__dict__.copy()
544 info['states'] = '\n'.join(state.toxml() for state in self.states)
545 info['w_groups'] = '\n'.join(wg.toxml() for wg in self.weight_groups)
546 info['feature_detector_name'] = (info['feature_detector_name']
547 .replace('&', '&')
548 .replace('<', '<'))
549 try:
550 fd = pickle.dumps(self.feature_detector)
551 fd = fd.replace('&', '&').replace('<', '<')
552 fd = fd.replace('\n', ' ')
553 info['feature_detector'] = '<pickle>%s</pickle>' % fd
554 except pickle.PicklingError:
555 info['feature_detector'] = ''
556 return self._XML_TEMPLATE % info
557
558 @staticmethod
561
562 @staticmethod
564 states = [CRFInfo.State._read(et) for et in
565 etree.findall('states/state')]
566 weight_groups = [CRFInfo.WeightGroup._read(et) for et in
567 etree.findall('weightGroups/weightGroup')]
568 fd = etree.find('featureDetector')
569 feature_detector = fd.get('name')
570 if fd.find('pickle') is not None:
571 try: feature_detector = pickle.loads(fd.find('pickle').text)
572 except pickle.PicklingError, e: pass
573
574 return CRFInfo(states,
575 float(etree.find('gaussianVariance').text),
576 etree.find('defaultLabel').text,
577 int(etree.find('maxIterations').text),
578 etree.find('transductionType').text,
579 weight_groups,
580 bool(etree.find('addStartState').text),
581 bool(etree.find('addEndState').text),
582 etree.find('modelFile').text,
583 feature_detector)
584
585 - def write(self, filename):
590
592 """
593 A description of a single CRF state.
594 """
595 - def __init__(self, name, initial_cost, final_cost, transitions):
596 if initial_cost != '+inf': initial_cost = float(initial_cost)
597 if final_cost != '+inf': final_cost = float(final_cost)
598 self.name = name
599 self.initial_cost = initial_cost
600 self.final_cost = final_cost
601 self.transitions = transitions
602
603 _XML_TEMPLATE = (
604 ' <state name="%(name)s" initialCost="%(initial_cost)s" '
605 'finalCost="%(final_cost)s">\n'
606 ' <transitions>\n'
607 '%(transitions)s\n'
608 ' </transitions>\n'
609 ' </state>\n')
611 info = self.__dict__.copy()
612 info['transitions'] = '\n'.join(transition.toxml()
613 for transition in self.transitions)
614 return self._XML_TEMPLATE % info
615
616 @staticmethod
624
626 """
627 A description of a single CRF transition.
628 """
629 - def __init__(self, destination, label, weightgroups):
630 """
631 @param destination: The name of the state that this transition
632 connects to.
633 @param label: The tag that is generated when traversing this
634 transition.
635 @param weightgroups: A list of L{WeightGroup} names, indicating
636 which weight groups should be used to calculate the cost
637 of traversing this transition.
638 """
639 self.destination = destination
640 self.label = label
641 self.weightgroups = weightgroups
642
643 _XML_TEMPLATE = (' <transition label="%(label)s" '
644 'destination="%(destination)s" '
645 'weightGroups="%(w_groups)s"/>')
647 info = self.__dict__
648 info['w_groups'] = ' '.join(wg for wg in self.weightgroups)
649 return self._XML_TEMPLATE % info
650
651 @staticmethod
656
658 """
659 A configuration object used by C{MalletCRF} to specify how
660 input-features (which are a function of only the input) should be
661 mapped to joint-features (which are a function of both the input
662 and the output tags).
663
664 Each weight group specifies that a given set of input features
665 should be paired with all transitions from a given set of source
666 tags to a given set of destination tags.
667 """
668 - def __init__(self, name, src, dst, features='.*'):
669 """
670 @param name: A unique name for this weight group.
671 @param src: The set of source tags that should be used for
672 this weight group, specified as either a list of state
673 names or a regular expression.
674 @param dst: The set of destination tags that should be used
675 for this weight group, specified as either a list of state
676 names or a regular expression.
677 @param features: The set of input feature that should be used
678 for this weight group, specified as either a list of
679 feature names or a regular expression. WARNING: currently,
680 this regexp is passed streight to java -- i.e., it must
681 be a java-style regexp!
682 """
683 if re.search('\s', name):
684 raise ValueError('weight group name may not '
685 'contain whitespace.')
686 if re.search('"', name):
687 raise ValueError('weight group name may not contain \'"\'.')
688 self.name = name
689 self.src = src
690 self.dst = dst
691 self.features = features
692 self._src_match_cache = {}
693 self._dst_match_cache = {}
694
    # Template for the XML serialization of a weight group; the
    # resulting element is consumed by the java side.  (The toxml()
    # method that fills this template is not visible in this chunk.)
    _XML_TEMPLATE = ('  <weightGroup name="%(name)s" src="%(src)s" '
                     'dst="%(dst)s" features="%(features)s" />')
699
700 @staticmethod
706
707
708 - def match(self, src, dst):
709
710 src_match = self._src_match_cache.get(src)
711 if src_match is None:
712 if isinstance(self.src, basestring):
713 src_match = bool(re.match(self.src+'\Z', src))
714 else:
715 src_match = src in self.src
716 self._src_match_cache[src] = src_match
717
718
719 dst_match = self._dst_match_cache.get(dst)
720 if dst_match is None:
721 if isinstance(self.dst, basestring):
722 dst_match = bool(re.match(self.dst+'\Z', dst))
723 else:
724 dst_match = dst in self.dst
725 self._dst_match_cache[dst] = dst_match
726
727
728 return src_match and dst_match
729
730
731
732
733
734 -def demo(train_size=100, test_size=100,
735 java_home='/usr/local/jdk1.5.0/',
736 mallet_home='/usr/local/mallet-0.4'):
744
745
746 nltk.internals.config_java(java_home)
747 nltk.classify.mallet.config_mallet(mallet_home)
748
749
750
751 def strip(corpus): return [[(w, t[:2]) for (w,t) in sent]
752 for sent in corpus]
753 brown_train = strip(brown.tagged_sents(categories='a')[:train_size])
754 brown_test = strip(brown.tagged_sents(categories='b')[:test_size])
755
756 crf = MalletCRF.train(fd, brown_train,
757 transduction_type='VITERBI')
758 sample_output = crf.tag([w for (w,t) in brown_test[5]])
759 acc = nltk.tag.accuracy(crf, brown_test)
760 print '\nAccuracy: %.1f%%' % (acc*100)
761 print 'Sample output:'
762 print textwrap.fill(' '.join('%s/%s' % w for w in sample_output),
763 initial_indent=' ', subsequent_indent=' ')+'\n'
764
765
766 print 'Clean-up: deleting', crf.filename
767 os.remove(crf.filename)
768
# When run as a script, train and evaluate the demo tagger (requires
# java and mallet; see demo()).
if __name__ == '__main__':
    demo(train_size=100)
771