mirror of
				https://github.com/containers/podman.git
				synced 2025-10-31 18:08:51 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			414 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			414 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2018 The go-libvirt Authors.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //   http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package libvirt
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 	"sync/atomic"
 | |
| 
 | |
| 	"github.com/digitalocean/go-libvirt/internal/constants"
 | |
| 	"github.com/digitalocean/go-libvirt/internal/event"
 | |
| 	xdr "github.com/digitalocean/go-libvirt/internal/go-xdr/xdr2"
 | |
| 	"github.com/digitalocean/go-libvirt/socket"
 | |
| )
 | |
| 
 | |
// ErrUnsupported is the sentinel error reported when libvirt rejects a
// request because it does not know the requested procedure.
var ErrUnsupported = errors.New("unsupported procedure requested")
 | |
| 
 | |
// response is the package-internal representation of one RPC reply packet.
type response struct {
	// Payload holds the raw, still-encoded body of the reply.
	Payload []byte
	// Status is the status value copied from the packet header.
	Status uint32
}
 | |
| 
 | |
// Error is an error response from libvirt.
type Error struct {
	// Code is the numeric error code reported by libvirt.
	Code uint32
	// Message is the textual description that accompanied the code.
	Message string
}

// Error implements the error interface; it returns the libvirt message
// verbatim and does not include the numeric code.
func (e Error) Error() string {
	return e.Message
}
 | |
| 
 | |
| // checkError is used to check whether an error is a libvirtError, and if it is,
 | |
| // whether its error code matches the one passed in. It will return false if
 | |
| // these conditions are not met.
 | |
| func checkError(err error, expectedError ErrorNumber) bool {
 | |
| 	for err != nil {
 | |
| 		e, ok := err.(Error)
 | |
| 		if ok {
 | |
| 			return e.Code == uint32(expectedError)
 | |
| 		}
 | |
| 		err = errors.Unwrap(err)
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // IsNotFound detects libvirt's ERR_NO_DOMAIN.
 | |
| func IsNotFound(err error) bool {
 | |
| 	return checkError(err, ErrNoDomain)
 | |
| }
 | |
| 
 | |
| // callback sends RPC responses to respective callers.
 | |
| func (l *Libvirt) callback(id int32, res response) {
 | |
| 	l.cmux.Lock()
 | |
| 	defer l.cmux.Unlock()
 | |
| 
 | |
| 	c, ok := l.callbacks[id]
 | |
| 	if !ok {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	c <- res
 | |
| }
 | |
| 
 | |
| // Route sends incoming packets to their listeners.
 | |
| func (l *Libvirt) Route(h *socket.Header, buf []byte) {
 | |
| 	// Route events to their respective listener
 | |
| 	var event event.Event
 | |
| 
 | |
| 	switch {
 | |
| 	case h.Program == constants.QEMUProgram && h.Procedure == constants.QEMUProcDomainMonitorEvent:
 | |
| 		event = &DomainEvent{}
 | |
| 	case h.Program == constants.Program && h.Procedure == constants.ProcDomainEventCallbackLifecycle:
 | |
| 		event = &DomainEventCallbackLifecycleMsg{}
 | |
| 	}
 | |
| 
 | |
| 	if event != nil {
 | |
| 		err := eventDecoder(buf, event)
 | |
| 		if err != nil { // event was malformed, drop.
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		l.stream(event)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// send response to caller
 | |
| 	l.callback(h.Serial, response{Payload: buf, Status: h.Status})
 | |
| }
 | |
| 
 | |
| // serial provides atomic access to the next sequential request serial number.
 | |
| func (l *Libvirt) serial() int32 {
 | |
| 	return atomic.AddInt32(&l.s, 1)
 | |
| }
 | |
| 
 | |
| // stream decodes and relays domain events to their respective listener.
 | |
| func (l *Libvirt) stream(e event.Event) {
 | |
| 	l.emux.RLock()
 | |
| 	defer l.emux.RUnlock()
 | |
| 
 | |
| 	q, ok := l.events[e.GetCallbackID()]
 | |
| 	if !ok {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	q.Push(e)
 | |
| }
 | |
| 
 | |
| // addStream configures the routing for an event stream.
 | |
| func (l *Libvirt) addStream(s *event.Stream) {
 | |
| 	l.emux.Lock()
 | |
| 	defer l.emux.Unlock()
 | |
| 
 | |
| 	l.events[s.CallbackID] = s
 | |
| }
 | |
| 
 | |
| // removeStream deletes an event stream. The caller should first notify libvirt
 | |
| // to stop sending events for this stream. Subsequent calls to removeStream are
 | |
| // idempotent and return nil.
 | |
| func (l *Libvirt) removeStream(id int32) error {
 | |
| 	l.emux.Lock()
 | |
| 	defer l.emux.Unlock()
 | |
| 
 | |
| 	// if the event is already removed, just return nil
 | |
| 	q, ok := l.events[id]
 | |
| 	if ok {
 | |
| 		delete(l.events, id)
 | |
| 		q.Shutdown()
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // removeAllStreams deletes all event streams.  This is meant to be used to
 | |
| // clean up only once the underlying connection to libvirt is disconnected and
 | |
| // thus does not attempt to notify libvirt to stop sending events.
 | |
| func (l *Libvirt) removeAllStreams() {
 | |
| 	l.emux.Lock()
 | |
| 	defer l.emux.Unlock()
 | |
| 
 | |
| 	for _, ev := range l.events {
 | |
| 		ev.Shutdown()
 | |
| 		delete(l.events, ev.CallbackID)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // register configures a method response callback
 | |
| func (l *Libvirt) register(id int32, c chan response) {
 | |
| 	l.cmux.Lock()
 | |
| 	defer l.cmux.Unlock()
 | |
| 
 | |
| 	l.callbacks[id] = c
 | |
| }
 | |
| 
 | |
| // deregister destroys a method response callback. It is the responsibility of
 | |
| // the caller to manage locking (l.cmux) during this call.
 | |
| func (l *Libvirt) deregister(id int32) {
 | |
| 	_, ok := l.callbacks[id]
 | |
| 	if !ok {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	close(l.callbacks[id])
 | |
| 	delete(l.callbacks, id)
 | |
| }
 | |
| 
 | |
| // deregisterAll closes all waiting callback channels. This is used to clean up
 | |
| // if the connection to libvirt is lost. Callers waiting for responses will
 | |
| // return an error when the response channel is closed, rather than just
 | |
| // hanging.
 | |
| func (l *Libvirt) deregisterAll() {
 | |
| 	l.cmux.Lock()
 | |
| 	defer l.cmux.Unlock()
 | |
| 
 | |
| 	for id := range l.callbacks {
 | |
| 		l.deregister(id)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // request performs a libvirt RPC request.
 | |
| // returns response returned by server.
 | |
| // if response is not OK, decodes error from it and returns it.
 | |
| func (l *Libvirt) request(proc uint32, program uint32, payload []byte) (response, error) {
 | |
| 	return l.requestStream(proc, program, payload, nil, nil)
 | |
| }
 | |
| 
 | |
| // requestStream performs a libvirt RPC request. The `out` and `in` parameters
 | |
| // are optional, and should be nil when RPC endpoints don't return a stream.
 | |
| func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte,
 | |
| 	out io.Reader, in io.Writer) (response, error) {
 | |
| 	serial := l.serial()
 | |
| 	c := make(chan response)
 | |
| 
 | |
| 	l.register(serial, c)
 | |
| 	defer func() {
 | |
| 		l.cmux.Lock()
 | |
| 		defer l.cmux.Unlock()
 | |
| 
 | |
| 		l.deregister(serial)
 | |
| 	}()
 | |
| 
 | |
| 	err := l.socket.SendPacket(serial, proc, program, payload, socket.Call,
 | |
| 		socket.StatusOK)
 | |
| 	if err != nil {
 | |
| 		return response{}, err
 | |
| 	}
 | |
| 
 | |
| 	resp, err := l.getResponse(c)
 | |
| 	if err != nil {
 | |
| 		return resp, err
 | |
| 	}
 | |
| 
 | |
| 	if out != nil {
 | |
| 		abort := make(chan bool)
 | |
| 		outErr := make(chan error)
 | |
| 		go func() {
 | |
| 			outErr <- l.socket.SendStream(serial, proc, program, out, abort)
 | |
| 		}()
 | |
| 
 | |
| 		// Even without incoming stream server sends confirmation once all data is received
 | |
| 		resp, err = l.processIncomingStream(c, in)
 | |
| 		if err != nil {
 | |
| 			abort <- true
 | |
| 			return resp, err
 | |
| 		}
 | |
| 
 | |
| 		err = <-outErr
 | |
| 		if err != nil {
 | |
| 			return response{}, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch in {
 | |
| 	case nil:
 | |
| 		return resp, nil
 | |
| 	default:
 | |
| 		return l.processIncomingStream(c, in)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // processIncomingStream is called once we've successfully sent a request to
 | |
| // libvirt. It writes the responses back to the stream passed by the caller
 | |
| // until libvirt sends a packet with statusOK or an error.
 | |
| func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (response, error) {
 | |
| 	for {
 | |
| 		resp, err := l.getResponse(c)
 | |
| 		if err != nil {
 | |
| 			return resp, err
 | |
| 		}
 | |
| 
 | |
| 		// StatusOK indicates end of stream
 | |
| 		if resp.Status == socket.StatusOK {
 | |
| 			return resp, nil
 | |
| 		}
 | |
| 
 | |
| 		// FIXME: this smells.
 | |
| 		// StatusError is handled in getResponse, so this must be StatusContinue
 | |
| 		// StatusContinue is only valid here for stream packets
 | |
| 		// libvirtd breaks protocol and returns StatusContinue with an
 | |
| 		// empty response Payload when the stream finishes
 | |
| 		if len(resp.Payload) == 0 {
 | |
| 			return resp, nil
 | |
| 		}
 | |
| 		if inStream != nil {
 | |
| 			_, err = inStream.Write(resp.Payload)
 | |
| 			if err != nil {
 | |
| 				return response{}, err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (l *Libvirt) getResponse(c chan response) (response, error) {
 | |
| 	resp := <-c
 | |
| 	if resp.Status == socket.StatusError {
 | |
| 		return resp, decodeError(resp.Payload)
 | |
| 	}
 | |
| 
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
| // encode XDR encodes the provided data.
 | |
| func encode(data interface{}) ([]byte, error) {
 | |
| 	var buf bytes.Buffer
 | |
| 	_, err := xdr.Marshal(&buf, data)
 | |
| 
 | |
| 	return buf.Bytes(), err
 | |
| }
 | |
| 
 | |
| // decodeError extracts an error message from the provider buffer.
 | |
| func decodeError(buf []byte) error {
 | |
| 	dec := xdr.NewDecoder(bytes.NewReader(buf))
 | |
| 
 | |
| 	e := struct {
 | |
| 		Code     uint32
 | |
| 		DomainID uint32
 | |
| 		Padding  uint8
 | |
| 		Message  string
 | |
| 		Level    uint32
 | |
| 	}{}
 | |
| 	_, err := dec.Decode(&e)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if strings.Contains(e.Message, "unknown procedure") {
 | |
| 		return ErrUnsupported
 | |
| 	}
 | |
| 
 | |
| 	// if libvirt returns ERR_OK, ignore the error
 | |
| 	if ErrorNumber(e.Code) == ErrOk {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return Error{Code: uint32(e.Code), Message: e.Message}
 | |
| }
 | |
| 
 | |
| // eventDecoder decodes an event from a xdr buffer.
 | |
| func eventDecoder(buf []byte, e interface{}) error {
 | |
| 	dec := xdr.NewDecoder(bytes.NewReader(buf))
 | |
| 	_, err := dec.Decode(e)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
// typedParamDecoder is a stateless type whose Decode method reads libvirt
// TypedParam values from an XDR stream.
type typedParamDecoder struct{}
 | |
| 
 | |
| // Decode decodes a TypedParam. These are part of the libvirt spec, and not xdr
 | |
| // proper. TypedParams contain a name, which is called Field for some reason,
 | |
| // and a Value, which itself has a "discriminant" - an integer enum encoding the
 | |
| // actual type, and a value, the length of which varies based on the actual
 | |
| // type.
 | |
| func (tpd typedParamDecoder) Decode(d *xdr.Decoder, v reflect.Value) (int, error) {
 | |
| 	// Get the name of the typed param first
 | |
| 	name, n, err := d.DecodeString()
 | |
| 	if err != nil {
 | |
| 		return n, err
 | |
| 	}
 | |
| 	val, n2, err := tpd.decodeTypedParamValue(d)
 | |
| 	n += n2
 | |
| 	if err != nil {
 | |
| 		return n, err
 | |
| 	}
 | |
| 	tp := &TypedParam{Field: name, Value: *val}
 | |
| 	v.Set(reflect.ValueOf(*tp))
 | |
| 
 | |
| 	return n, nil
 | |
| }
 | |
| 
 | |
| // decodeTypedParamValue decodes the Value part of a TypedParam.
 | |
| func (typedParamDecoder) decodeTypedParamValue(d *xdr.Decoder) (*TypedParamValue, int, error) {
 | |
| 	// All TypedParamValues begin with a uint32 discriminant that tells us what
 | |
| 	// type they are.
 | |
| 	discriminant, n, err := d.DecodeUint()
 | |
| 	if err != nil {
 | |
| 		return nil, n, err
 | |
| 	}
 | |
| 	var n2 int
 | |
| 	var tpv *TypedParamValue
 | |
| 	switch discriminant {
 | |
| 	case 1:
 | |
| 		var val int32
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 2:
 | |
| 		var val uint32
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 3:
 | |
| 		var val int64
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 4:
 | |
| 		var val uint64
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 5:
 | |
| 		var val float64
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 6:
 | |
| 		var val int32
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 	case 7:
 | |
| 		var val string
 | |
| 		n2, err = d.Decode(&val)
 | |
| 		tpv = &TypedParamValue{D: discriminant, I: val}
 | |
| 
 | |
| 	default:
 | |
| 		err = fmt.Errorf("invalid parameter type %v", discriminant)
 | |
| 	}
 | |
| 	n += n2
 | |
| 
 | |
| 	return tpv, n, err
 | |
| }
 | 
![renovate[bot]](/assets/img/avatar_default.png)