Commit 1f57d8d4 authored by Chris Jones's avatar Chris Jones
Browse files

Bug 564086, part q: Generate C++ goop for creating |bridge| channels. r=bent

parent c29370b3
Loading
Loading
Loading
Loading
+176 −19
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ from copy import deepcopy

import ipdl.ast
from ipdl.cxx.ast import *
from ipdl.type import TypeVisitor
from ipdl.type import ActorType, ProcessGraph, TypeVisitor

# FIXME/cjones: the chromium Message logging code doesn't work on
# gcc/POSIX, because it wprintf()s across the chromium/mozilla
@@ -142,6 +142,9 @@ def _actorManager(actor):
def _actorState(actor):
    return ExprSelect(actor, '->', 'mState')

def _backstagePass():
    return ExprCall(ExprVar('mozilla::ipc::PrivateIPDLInterface'))

def _nullState(proto=None):
    pfx = ''
    if proto is not None:  pfx = proto.name() +'::'
@@ -1060,9 +1063,21 @@ class Protocol(ipdl.ast.Protocol):
    def otherProcessMethod(self):
        return ExprVar('OtherProcess')

    def callOtherProcess(self, actorThis=None):
        fn = self.otherProcessMethod()
        if actorThis is not None:
            fn = ExprSelect(actorThis, '->', fn.name)
        return ExprCall(fn)

    def getChannelMethod(self):
        return ExprVar('GetIPCChannel')

    def callGetChannel(self, actorThis=None):
        fn = self.getChannelMethod()
        if actorThis is not None:
            fn = ExprSelect(actorThis, '->', fn.name)
        return ExprCall(fn)

    def processingErrorVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar('ProcessingError')
@@ -1328,6 +1343,7 @@ class _GenerateProtocolCode(ipdl.ast.Visitor):
        self.protocol = None     # protocol we're generating a class for
        self.hdrfile = None      # what will become Protocol.h
        self.cppfile = None      # what will become Protocol.cpp
        self.cppIncludeHeaders = []
        self.structUnionDefns = []
        self.funcDefns = []

@@ -1350,14 +1366,12 @@ class _GenerateProtocolCode(ipdl.ast.Visitor):
        hf.addthings(_includeGuardEnd(hf))

        cf = self.cppfile
        cf.addthings([
            _DISCLAIMER,
            Whitespace.NL,
            CppDirective(
                'include',
                '"'+ _protocolHeaderName(self.protocol, '') +'.h"'),
            Whitespace.NL
        ])
        cf.addthings((
            [ _DISCLAIMER, Whitespace.NL ]
            + [ CppDirective('include','"'+h+'.h"')
                for h in self.cppIncludeHeaders ]
            + [ Whitespace.NL ]
        ))
       
        # construct the namespace into which we'll stick all our defns
        ns = Namespace(self.protocol.name)
@@ -1403,6 +1417,19 @@ class _GenerateProtocolCode(ipdl.ast.Visitor):
                                              *_generateCxxUnion(ud))

    def visitProtocol(self, p):
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, ''))
        bridges = ProcessGraph.bridgesOf(p.decl.type)
        for bridge in bridges:
            ppt, pside = bridge.parent.ptype, _otherSide(bridge.parent.side)
            cpt, cside = bridge.child.ptype, _otherSide(bridge.child.side)
            self.hdrfile.addthings([
                Whitespace.NL,
                _makeForwardDeclForActor(ppt, pside),
                _makeForwardDeclForActor(cpt, cside)
            ])
            self.cppIncludeHeaders.append(_protocolHeaderName(ppt._p, pside))
            self.cppIncludeHeaders.append(_protocolHeaderName(cpt._p, cside))

        self.hdrfile.addthing(Whitespace("""
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
@@ -1414,6 +1441,12 @@ class _GenerateProtocolCode(ipdl.ast.Visitor):
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        # user-facing methods for connecting two process with a new channel
        for bridge in bridges:
            bdecl, bdefn = _splitFuncDeclDefn(self.genBridgeFunc(bridge))
            ns.addstmts([ bdecl, Whitespace.NL ])
            self.funcDefns.append(bdefn)

        # state information
        stateenum = TypeEnum('State')
        # NB: __Dead is the first state on purpose, so that it has
@@ -1465,6 +1498,31 @@ class _GenerateProtocolCode(ipdl.ast.Visitor):
        ns.addstmts([ Whitespace.NL, Whitespace.NL ])


    def genBridgeFunc(self, bridge):
        p = self.protocol
        parentHandleType = _cxxBareType(ActorType(bridge.parent.ptype),
                                        _otherSide(bridge.parent.side))
        parentvar = ExprVar('parentHandle')

        childHandleType = _cxxBareType(ActorType(bridge.child.ptype),
                                       _otherSide(bridge.child.side))
        childvar = ExprVar('childHandle')

        bridgefunc = MethodDefn(MethodDecl(
            'Bridge',
            params=[ Decl(parentHandleType, parentvar.name),
                     Decl(childHandleType, childvar.name) ],
            ret=Type.BOOL))
        bridgefunc.addstmt(StmtReturn(ExprCall(
            ExprVar('mozilla::ipc::Bridge'),
            args=[ _backstagePass(),
                   p.callGetChannel(parentvar), p.callOtherProcess(parentvar),
                   p.callGetChannel(childvar), p.callOtherProcess(childvar),
                   _protocolId(p.decl.type)
                   ])))
        return bridgefunc


    def genTransitionFunc(self):
        ptype = self.protocol.decl.type
        usesend, sendvar = set(), ExprVar('__Send')
@@ -2416,6 +2474,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                       Inherit(p.managerInterfaceType(), viz='protected') ],
            abstract=True)

        bridgeActorsCreated = ProcessGraph.bridgeEndpointsOf(ptype, self.side)

        friends = _FindFriends().findFriends(ptype)
        if ptype.isManaged():
            friends.update(ptype.managers)
@@ -2437,11 +2497,27 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                                           self.prettyside)),
                Whitespace.NL ])

        for actor in bridgeActorsCreated:
            self.hdrfile.addthings([
                Whitespace.NL,
                _makeForwardDeclForActor(actor.ptype, actor.side),
                Whitespace.NL
            ])

        self.cls.addstmt(Label.PROTECTED)
        for typedef in p.cxxTypedefs():
            self.cls.addstmt(typedef)
        for typedef in self.includedActorTypedefs:
            self.cls.addstmt(typedef)
        # XXX these don't really fit in the other lists; just include
        # them here for now
        self.cls.addstmts([
            Typedef(Type('base::ProcessId'), 'ProcessId'),
            Typedef(Type('mozilla::ipc::ProtocolId'), 'ProtocolId'),
            Typedef(Type('mozilla::ipc::Transport'), 'Transport'),
            Typedef(Type('mozilla::ipc::TransportDescriptor'), 'TransportDescriptor')
        ])

        self.cls.addstmt(Whitespace.NL)

        self.cls.addstmts([ Typedef(p.fqStateType(), 'State'), Whitespace.NL ])
@@ -2487,6 +2563,17 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                ret=Type.BOOL,
                virtual=1, pure=1)))

        for actor in bridgeActorsCreated:
            # add the Alloc interface for actors created when this
            # protocol is bridged to another
            actortype = _cxxBareType(actor.asType(), actor.side)
            self.cls.addstmt(StmtDecl(MethodDecl(
                _allocMethod(actor.ptype).name,
                params=[ Decl(Type('Transport', ptr=1), 'transport'),
                         Decl(Type('ProcessId'), 'otherProcess') ],
                ret=actortype,
                virtual=1, pure=1)))

        # optional ActorDestroy() method; default is no-op
        self.cls.addstmts([
            Whitespace.NL,
@@ -2571,6 +2658,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
            aTransportVar = ExprVar('aTransport')
            aThreadVar = ExprVar('aThread')
            processvar = ExprVar('aOtherProcess')
            sidevar = ExprVar('aSide')
            openmeth = MethodDefn(
                MethodDecl(
                    'Open',
@@ -2579,13 +2667,16 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                             Decl(Type('ProcessHandle'), processvar.name),
                             Param(Type('MessageLoop', ptr=True),
                                   aThreadVar.name,
                                   default=ExprLiteral.NULL) ],
                                   default=ExprLiteral.NULL),
                             Param(Type('AsyncChannel::Side'),
                                   sidevar.name,
                                   default=ExprVar('Channel::Unknown')) ],
                    ret=Type.BOOL))

            openmeth.addstmts([
                StmtExpr(ExprAssn(p.otherProcessVar(), processvar)),
                StmtReturn(ExprCall(ExprSelect(p.channelVar(), '.', 'Open'),
                                    [ aTransportVar, aThreadVar ]))
                                    [ aTransportVar, aThreadVar, sidevar ]))
            ])
            self.cls.addstmts([
                openmeth,
@@ -2676,6 +2767,10 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
        for md in p.messageDecls:
            self.visitMessageDecl(md)

        # Handlers for the creation of "bridge" actors
        if len(bridgeActorsCreated):
            self.makeBridgeHandlers(bridgeActorsCreated)

        # add default cases
        default = StmtBlock()
        default.addstmt(StmtReturn(_Result.NotKnown))
@@ -2959,8 +3054,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                Whitespace.NL
            ])

            ifkill = StmtIf(ExprNot(
                _killProcess(ExprCall(p.otherProcessMethod()))))
            ifkill = StmtIf(ExprNot(_killProcess(p.callOtherProcess())))
            ifkill.addifstmt(
                _printErrorMessage("  may have failed to kill child!"))
            fatalerror.addstmt(ifkill)
@@ -3254,7 +3348,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                               p.nextShmemIdExpr(self.side) ]),
                StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
                         init=_shmemShareTo(shmemvar,
                                            ExprCall(p.otherProcessMethod()),
                                            p.callOtherProcess(),
                                            p.routingId()))
            ])
            failif = StmtIf(ExprNot(descriptorvar))
@@ -3295,7 +3389,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                               p.nextShmemIdExpr(self.side) ]),
                StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
                         init=_shmemShareTo(shmemvar,
                                            ExprCall(p.otherProcessMethod()),
                                            p.callOtherProcess(),
                                            p.routingId()))
            ])
            failif = StmtIf(ExprNot(descriptorvar))
@@ -3349,7 +3443,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
                StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
                         init=_shmemUnshareFrom(
                             shmemvar,
                             ExprCall(p.otherProcessMethod()),
                             p.callOtherProcess(),
                             p.routingId())),
                Whitespace.NL,
                StmtExpr(p.removeShmemId(idvar)),
@@ -3415,9 +3509,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
            destroyshmem.addstmt(StmtReturn(ExprCall(
                ExprSelect(p.managerVar(), '->', p.destroySharedMemory().name),
                [ shmemvar ])))
            otherprocess.addstmt(StmtReturn(ExprCall(
                ExprSelect(p.managerVar(), '->',
                           p.otherProcessMethod().name))))
            otherprocess.addstmt(StmtReturn(
                p.callOtherProcess(p.managerVar())))
            getchannel.addstmt(StmtReturn(p.channelVar()))

        # all protocols share the "same" RemoveManagee() implementation
@@ -3662,6 +3755,70 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
        return case


    def makeBridgeHandlers(self, bridgeActors):
        handlers = StmtBlock()

        # unpack the transport descriptor et al.
        msgvar = self.msgvar
        tdvar = ExprVar('td')
        pidvar = ExprVar('pid')
        pvar = ExprVar('p')
        iffail = StmtIf(ExprNot(ExprCall(
            ExprVar('mozilla::ipc::UnpackChannelOpened'),
            args=[ _backstagePass(),
                   msgvar,
                   ExprAddrOf(tdvar), ExprAddrOf(pidvar), ExprAddrOf(pvar) ])))
        iffail.addifstmt(StmtReturn(_Result.PayloadError))
        handlers.addstmts([
            StmtDecl(Decl(Type('TransportDescriptor'), tdvar.name)),
            StmtDecl(Decl(Type('ProcessId'), pidvar.name)),
            StmtDecl(Decl(Type('ProtocolId'), pvar.name)),
            iffail,
            Whitespace.NL
        ])

        def makeHandlerCase(actor):
            case = StmtBlock()
            if actor.side is 'parent':  mode = 'SERVER'
            elif actor.side is 'child': mode = 'CLIENT'
            modevar = ExprVar('Transport::MODE_'+ mode)
            tvar = ExprVar('t')
            iffailopen = StmtIf(ExprNot(ExprAssn(
                tvar,
                ExprCall(ExprVar('mozilla::ipc::OpenDescriptor'),
                         args=[ tdvar, modevar ]))))
            iffailopen.addifstmt(StmtReturn(_Result.ValuError))

            iffailalloc = StmtIf(ExprNot(ExprCall(
                _allocMethod(actor.ptype),
                args=[ tvar, pidvar ])))
            iffailalloc.addifstmt(StmtReturn(_Result.ProcessingError))

            case.addstmts([
                StmtDecl(Decl(Type('Transport', ptr=1), tvar.name)),
                iffailopen,
                iffailalloc,
                StmtBreak()
            ])
            return CaseLabel(_protocolId(actor.ptype).name), case

        pswitch = StmtSwitch(pvar)
        for actor in bridgeActors:
            label, case = makeHandlerCase(actor)
            pswitch.addcase(label, case)

        die = Block()
        die.addstmts([ _runtimeAbort('Invalid protocol'),
                       StmtReturn(_Result.ValuError) ])
        pswitch.addcase(DefaultLabel(), die)

        handlers.addstmts([
            pswitch,
            StmtReturn(_Result.Processed)
        ])
        self.asyncSwitch.addcase(CaseLabel('CHANNEL_OPENED_MESSAGE_TYPE'),
                                 handlers)

    ##-------------------------------------------------------------------------
    ## The next few functions are the crux of the IPDL code generator.
    ## They generate code for all the nasty work of message