Source file
src/net/rpc/server.go
Documentation: net/rpc
1
2
3
4
5
127 package rpc
128
129 import (
130 "bufio"
131 "encoding/gob"
132 "errors"
133 "io"
134 "log"
135 "net"
136 "net/http"
137 "reflect"
138 "strings"
139 "sync"
140 "unicode"
141 "unicode/utf8"
142 )
143
// DefaultRPCPath is the HTTP path at which HandleHTTP mounts the RPC endpoint.
const DefaultRPCPath = "/_goRPC_"

// DefaultDebugPath is the HTTP path at which HandleHTTP mounts the debug handler.
const DefaultDebugPath = "/debug/rpc"
149
150
151
// typeOfError is the reflect.Type of the built-in error interface. It is
// obtained through a pointer and Elem because reflect.TypeOf on a nil
// interface value returns nil rather than the interface type.
var typeOfError = reflect.TypeOf(new(error)).Elem()
153
// methodType describes one RPC-callable method of a registered service.
type methodType struct {
	sync.Mutex // guards numCalls

	method reflect.Method // method.Func is invoked with the receiver as first argument

	// ArgType is the method's second parameter (the request argument);
	// ReplyType is its third parameter, always a pointer type.
	ArgType, ReplyType reflect.Type

	numCalls uint // number of invocations so far
}
161
162 type service struct {
163 name string
164 rcvr reflect.Value
165 typ reflect.Type
166 method map[string]*methodType
167 }
168
169
170
171
// Request is the header that precedes every RPC call. The server decodes it
// with ServerCodec.ReadRequestHeader before reading the call's argument.
type Request struct {
	// ServiceMethod names the target as "Service.Method".
	ServiceMethod string
	// Seq is echoed back in the matching Response so the caller can pair them.
	Seq uint64

	next *Request // link in the Server's free list
}
177
178
179
180
// Response is the header written before every RPC reply. Its ServiceMethod
// and Seq are copied from the Request being answered.
type Response struct {
	// ServiceMethod echoes the Request's ServiceMethod.
	ServiceMethod string
	// Seq echoes the Request's Seq.
	Seq uint64
	// Error holds the error text; it is empty when the call succeeded.
	Error string

	next *Response // link in the Server's free list
}
187
188
189 type Server struct {
190 serviceMap sync.Map
191 reqLock sync.Mutex
192 freeReq *Request
193 respLock sync.Mutex
194 freeResp *Response
195 }
196
197
198 func NewServer() *Server {
199 return &Server{}
200 }
201
202
203 var DefaultServer = NewServer()
204
205
// isExported reports whether name begins with an upper-case letter, the Go
// rule for an identifier being exported. An empty name or one starting with
// invalid UTF-8 is not exported.
func isExported(name string) bool {
	if first, _ := utf8.DecodeRuneInString(name); unicode.IsUpper(first) {
		return true
	}
	return false
}
210
211
212 func isExportedOrBuiltinType(t reflect.Type) bool {
213 for t.Kind() == reflect.Ptr {
214 t = t.Elem()
215 }
216
217
218 return isExported(t.Name()) || t.PkgPath() == ""
219 }
220
221
222
223
224
225
226
227
228
229
230
231 func (server *Server) Register(rcvr interface{}) error {
232 return server.register(rcvr, "", false)
233 }
234
235
236
237 func (server *Server) RegisterName(name string, rcvr interface{}) error {
238 return server.register(rcvr, name, true)
239 }
240
241 func (server *Server) register(rcvr interface{}, name string, useName bool) error {
242 s := new(service)
243 s.typ = reflect.TypeOf(rcvr)
244 s.rcvr = reflect.ValueOf(rcvr)
245 sname := reflect.Indirect(s.rcvr).Type().Name()
246 if useName {
247 sname = name
248 }
249 if sname == "" {
250 s := "rpc.Register: no service name for type " + s.typ.String()
251 log.Print(s)
252 return errors.New(s)
253 }
254 if !isExported(sname) && !useName {
255 s := "rpc.Register: type " + sname + " is not exported"
256 log.Print(s)
257 return errors.New(s)
258 }
259 s.name = sname
260
261
262 s.method = suitableMethods(s.typ, true)
263
264 if len(s.method) == 0 {
265 str := ""
266
267
268 method := suitableMethods(reflect.PtrTo(s.typ), false)
269 if len(method) != 0 {
270 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
271 } else {
272 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
273 }
274 log.Print(str)
275 return errors.New(str)
276 }
277
278 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
279 return errors.New("rpc: service already defined: " + sname)
280 }
281 return nil
282 }
283
284
285
286 func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
287 methods := make(map[string]*methodType)
288 for m := 0; m < typ.NumMethod(); m++ {
289 method := typ.Method(m)
290 mtype := method.Type
291 mname := method.Name
292
293 if method.PkgPath != "" {
294 continue
295 }
296
297 if mtype.NumIn() != 3 {
298 if reportErr {
299 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
300 }
301 continue
302 }
303
304 argType := mtype.In(1)
305 if !isExportedOrBuiltinType(argType) {
306 if reportErr {
307 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
308 }
309 continue
310 }
311
312 replyType := mtype.In(2)
313 if replyType.Kind() != reflect.Ptr {
314 if reportErr {
315 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
316 }
317 continue
318 }
319
320 if !isExportedOrBuiltinType(replyType) {
321 if reportErr {
322 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
323 }
324 continue
325 }
326
327 if mtype.NumOut() != 1 {
328 if reportErr {
329 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
330 }
331 continue
332 }
333
334 if returnType := mtype.Out(0); returnType != typeOfError {
335 if reportErr {
336 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
337 }
338 continue
339 }
340 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
341 }
342 return methods
343 }
344
345
346
347
// invalidRequest is the placeholder body sent instead of a reply whenever a
// response carries an error.
var invalidRequest struct{}
349
350 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
351 resp := server.getResponse()
352
353 resp.ServiceMethod = req.ServiceMethod
354 if errmsg != "" {
355 resp.Error = errmsg
356 reply = invalidRequest
357 }
358 resp.Seq = req.Seq
359 sending.Lock()
360 err := codec.WriteResponse(resp, reply)
361 if debugLog && err != nil {
362 log.Println("rpc: writing response:", err)
363 }
364 sending.Unlock()
365 server.freeResponse(resp)
366 }
367
368 func (m *methodType) NumCalls() (n uint) {
369 m.Lock()
370 n = m.numCalls
371 m.Unlock()
372 return n
373 }
374
375 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
376 if wg != nil {
377 defer wg.Done()
378 }
379 mtype.Lock()
380 mtype.numCalls++
381 mtype.Unlock()
382 function := mtype.method.Func
383
384 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
385
386 errInter := returnValues[0].Interface()
387 errmsg := ""
388 if errInter != nil {
389 errmsg = errInter.(error).Error()
390 }
391 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
392 server.freeRequest(req)
393 }
394
// gobServerCodec is the ServerCodec that ServeConn uses. It decodes requests
// with encoding/gob straight from rwc and encodes responses into a buffered
// writer over rwc.
type gobServerCodec struct {
	rwc    io.ReadWriteCloser // the connection
	dec    *gob.Decoder       // reads requests from rwc
	enc    *gob.Encoder       // writes responses into encBuf
	encBuf *bufio.Writer      // buffers enc's output; flushed by WriteResponse
	closed bool               // set by Close so rwc is closed only once
}
402
403 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
404 return c.dec.Decode(r)
405 }
406
407 func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
408 return c.dec.Decode(body)
409 }
410
411 func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
412 if err = c.enc.Encode(r); err != nil {
413 if c.encBuf.Flush() == nil {
414
415
416 log.Println("rpc: gob error encoding response:", err)
417 c.Close()
418 }
419 return
420 }
421 if err = c.enc.Encode(body); err != nil {
422 if c.encBuf.Flush() == nil {
423
424
425 log.Println("rpc: gob error encoding body:", err)
426 c.Close()
427 }
428 return
429 }
430 return c.encBuf.Flush()
431 }
432
433 func (c *gobServerCodec) Close() error {
434 if c.closed {
435
436 return nil
437 }
438 c.closed = true
439 return c.rwc.Close()
440 }
441
442
443
444
445
446
447 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
448 buf := bufio.NewWriter(conn)
449 srv := &gobServerCodec{
450 rwc: conn,
451 dec: gob.NewDecoder(conn),
452 enc: gob.NewEncoder(buf),
453 encBuf: buf,
454 }
455 server.ServeCodec(srv)
456 }
457
458
459
460 func (server *Server) ServeCodec(codec ServerCodec) {
461 sending := new(sync.Mutex)
462 wg := new(sync.WaitGroup)
463 for {
464 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
465 if err != nil {
466 if debugLog && err != io.EOF {
467 log.Println("rpc:", err)
468 }
469 if !keepReading {
470 break
471 }
472
473 if req != nil {
474 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
475 server.freeRequest(req)
476 }
477 continue
478 }
479 wg.Add(1)
480 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
481 }
482
483
484 wg.Wait()
485 codec.Close()
486 }
487
488
489
490 func (server *Server) ServeRequest(codec ServerCodec) error {
491 sending := new(sync.Mutex)
492 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
493 if err != nil {
494 if !keepReading {
495 return err
496 }
497
498 if req != nil {
499 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
500 server.freeRequest(req)
501 }
502 return err
503 }
504 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
505 return nil
506 }
507
508 func (server *Server) getRequest() *Request {
509 server.reqLock.Lock()
510 req := server.freeReq
511 if req == nil {
512 req = new(Request)
513 } else {
514 server.freeReq = req.next
515 *req = Request{}
516 }
517 server.reqLock.Unlock()
518 return req
519 }
520
521 func (server *Server) freeRequest(req *Request) {
522 server.reqLock.Lock()
523 req.next = server.freeReq
524 server.freeReq = req
525 server.reqLock.Unlock()
526 }
527
528 func (server *Server) getResponse() *Response {
529 server.respLock.Lock()
530 resp := server.freeResp
531 if resp == nil {
532 resp = new(Response)
533 } else {
534 server.freeResp = resp.next
535 *resp = Response{}
536 }
537 server.respLock.Unlock()
538 return resp
539 }
540
541 func (server *Server) freeResponse(resp *Response) {
542 server.respLock.Lock()
543 resp.next = server.freeResp
544 server.freeResp = resp
545 server.respLock.Unlock()
546 }
547
548 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
549 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
550 if err != nil {
551 if !keepReading {
552 return
553 }
554
555 codec.ReadRequestBody(nil)
556 return
557 }
558
559
560 argIsValue := false
561 if mtype.ArgType.Kind() == reflect.Ptr {
562 argv = reflect.New(mtype.ArgType.Elem())
563 } else {
564 argv = reflect.New(mtype.ArgType)
565 argIsValue = true
566 }
567
568 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
569 return
570 }
571 if argIsValue {
572 argv = argv.Elem()
573 }
574
575 replyv = reflect.New(mtype.ReplyType.Elem())
576
577 switch mtype.ReplyType.Elem().Kind() {
578 case reflect.Map:
579 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
580 case reflect.Slice:
581 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
582 }
583 return
584 }
585
586 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
587
588 req = server.getRequest()
589 err = codec.ReadRequestHeader(req)
590 if err != nil {
591 req = nil
592 if err == io.EOF || err == io.ErrUnexpectedEOF {
593 return
594 }
595 err = errors.New("rpc: server cannot decode request: " + err.Error())
596 return
597 }
598
599
600
601 keepReading = true
602
603 dot := strings.LastIndex(req.ServiceMethod, ".")
604 if dot < 0 {
605 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
606 return
607 }
608 serviceName := req.ServiceMethod[:dot]
609 methodName := req.ServiceMethod[dot+1:]
610
611
612 svci, ok := server.serviceMap.Load(serviceName)
613 if !ok {
614 err = errors.New("rpc: can't find service " + req.ServiceMethod)
615 return
616 }
617 svc = svci.(*service)
618 mtype = svc.method[methodName]
619 if mtype == nil {
620 err = errors.New("rpc: can't find method " + req.ServiceMethod)
621 }
622 return
623 }
624
625
626
627
628
629 func (server *Server) Accept(lis net.Listener) {
630 for {
631 conn, err := lis.Accept()
632 if err != nil {
633 log.Print("rpc.Serve: accept:", err.Error())
634 return
635 }
636 go server.ServeConn(conn)
637 }
638 }
639
640
641 func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
642
643
644
645 func RegisterName(name string, rcvr interface{}) error {
646 return DefaultServer.RegisterName(name, rcvr)
647 }
648
649
650
651
652
653
654
655
656 type ServerCodec interface {
657 ReadRequestHeader(*Request) error
658 ReadRequestBody(interface{}) error
659
660 WriteResponse(*Response, interface{}) error
661
662 Close() error
663 }
664
665
666
667
668
669
670 func ServeConn(conn io.ReadWriteCloser) {
671 DefaultServer.ServeConn(conn)
672 }
673
674
675
676 func ServeCodec(codec ServerCodec) {
677 DefaultServer.ServeCodec(codec)
678 }
679
680
681
682 func ServeRequest(codec ServerCodec) error {
683 return DefaultServer.ServeRequest(codec)
684 }
685
686
687
688
689 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
690
691
var (
	// connected is the status text ServeHTTP writes, after "HTTP/1.0 ", once
	// it has hijacked a CONNECT request and before it starts serving RPC.
	connected = "200 Connected to Go RPC"
)
693
694
695 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
696 if req.Method != "CONNECT" {
697 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
698 w.WriteHeader(http.StatusMethodNotAllowed)
699 io.WriteString(w, "405 must CONNECT\n")
700 return
701 }
702 conn, _, err := w.(http.Hijacker).Hijack()
703 if err != nil {
704 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
705 return
706 }
707 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
708 server.ServeConn(conn)
709 }
710
711
712
713
714 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
715 http.Handle(rpcPath, server)
716 http.Handle(debugPath, debugHTTP{server})
717 }
718
719
720
721
722 func HandleHTTP() {
723 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
724 }
725
View as plain text