Skip to content
34 changes: 0 additions & 34 deletions htlcswitch/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -2605,40 +2605,6 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
// forwarded.
availableBandwidth := l.Bandwidth()

auxBandwidth, externalErr := fn.MapOptionZ(
l.cfg.AuxTrafficShaper,
func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] {
var htlcBlob fn.Option[tlv.Blob]
blob, err := customRecords.Serialize()
if err != nil {
return fn.Err[OptionalBandwidth](
fmt.Errorf("unable to serialize "+
"custom records: %w", err))
}

if len(blob) > 0 {
htlcBlob = fn.Some(blob)
}

return l.AuxBandwidth(amt, originalScid, htlcBlob, ts)
},
).Unpack()
if externalErr != nil {
l.log.Errorf("Unable to determine aux bandwidth: %v",
externalErr)

return NewLinkError(&lnwire.FailTemporaryNodeFailure{})
}

if auxBandwidth.IsHandled && auxBandwidth.Bandwidth.IsSome() {
auxBandwidth.Bandwidth.WhenSome(
func(bandwidth lnwire.MilliSatoshi) {
availableBandwidth = bandwidth
},
)
}

// Check to see if there is enough balance in this channel.
if amt > availableBandwidth {
l.log.Warnf("insufficient bandwidth to route htlc: %v is "+
"larger than %v", amt, availableBandwidth)
Expand Down
39 changes: 39 additions & 0 deletions lnwallet/aux_test_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package lnwallet

import (
"github.com/lightningnetwork/lnd/lnwire"
)

// NewTestAuxHtlcDescriptor creates an AuxHtlcDescriptor for testing purposes.
// This function allows tests to create descriptors with specific commit heights
// and entry types, which are normally unexported fields.
func NewTestAuxHtlcDescriptor(
chanID lnwire.ChannelID,
rHash PaymentHash,
timeout uint32,
amount lnwire.MilliSatoshi,
htlcIndex uint64,
parentIndex uint64,
entryType uint8,
customRecords lnwire.CustomRecords,
addHeightLocal uint64,
addHeightRemote uint64,
removeHeightLocal uint64,
removeHeightRemote uint64,
) AuxHtlcDescriptor {

return AuxHtlcDescriptor{
ChanID: chanID,
RHash: rHash,
Timeout: timeout,
Amount: amount,
HtlcIndex: htlcIndex,
ParentIndex: parentIndex,
EntryType: updateType(entryType),
CustomRecords: customRecords,
addCommitHeightLocal: addHeightLocal,
addCommitHeightRemote: addHeightRemote,
removeCommitHeightLocal: removeHeightLocal,
removeCommitHeightRemote: removeHeightRemote,
}
}
93 changes: 87 additions & 6 deletions lnwallet/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,14 @@ type LightningChannel struct {
// is created.
type ChannelOpt func(*channelOpts)

// AuxHtlcValidator is a function that validates whether an HTLC can be added
// to a custom channel. It is called during HTLC validation with the current
// channel state and HTLC details. This allows external components (like the
// traffic shaper) to perform final validation checks against the most
// up-to-date channel state before the HTLC is committed.
type AuxHtlcValidator func(amount, linkBandwidth lnwire.MilliSatoshi,
customRecords lnwire.CustomRecords, view AuxHtlcView) error

// channelOpts is the set of options used to create a new channel.
type channelOpts struct {
localNonce *musig2.Nonces
Expand All @@ -842,6 +850,10 @@ type channelOpts struct {
auxSigner fn.Option[AuxSigner]
auxResolver fn.Option[AuxContractResolver]

// auxHtlcValidator is an optional validator that performs custom
// validation on HTLCs before they are added to the channel state.
auxHtlcValidator fn.Option[AuxHtlcValidator]

skipNonceInit bool
}

Expand Down Expand Up @@ -894,6 +906,15 @@ func WithAuxResolver(resolver AuxContractResolver) ChannelOpt {
}
}

// WithAuxHtlcValidator is used to specify a custom HTLC validator for the
// channel. This validator will be called during HTLC addition to perform
// final validation checks against the most up-to-date channel state.
func WithAuxHtlcValidator(validator AuxHtlcValidator) ChannelOpt {
return func(o *channelOpts) {
o.auxHtlcValidator = fn.Some(validator)
}
}

// defaultChannelOpts returns the set of default options for a new channel.
func defaultChannelOpts() *channelOpts {
return &channelOpts{}
Expand Down Expand Up @@ -2738,9 +2759,15 @@ func (lc *LightningChannel) FetchLatestAuxHTLCView() AuxHtlcView {
lc.RLock()
defer lc.RUnlock()

return newAuxHtlcView(lc.fetchHTLCView(
lc.updateLogs.Remote.logIndex, lc.updateLogs.Local.logIndex,
))
nextHeight := lc.commitChains.Local.tip().height + 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this nextHeight needed ?

remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote
view := lc.fetchHTLCView(
remoteACKedIndex, lc.updateLogs.Local.logIndex,
)

view.NextHeight = nextHeight

return newAuxHtlcView(view)
}

// fetchHTLCView returns all the candidate HTLC updates which should be
Expand Down Expand Up @@ -6061,7 +6088,7 @@ func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC,
defer lc.Unlock()

pd := lc.htlcAddDescriptor(htlc, openKey)
if err := lc.validateAddHtlc(pd, buffer); err != nil {
if err := lc.validateAddHtlc(pd, buffer, true); err != nil {
return 0, err
}

Expand Down Expand Up @@ -6179,7 +6206,7 @@ func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error {

// Enforce the FeeBuffer because we are evaluating whether we can add
// another htlc to the channel state.
if err := lc.validateAddHtlc(pd, FeeBuffer); err != nil {
if err := lc.validateAddHtlc(pd, FeeBuffer, false); err != nil {
lc.log.Debugf("May add outgoing htlc rejected: %v", err)
return err
}
Expand Down Expand Up @@ -6215,7 +6242,8 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
// validateAddHtlc validates the addition of an outgoing htlc to our local and
// remote commitments.
func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor,
buffer BufferType) error {
buffer BufferType, finalCheck bool) error {

// Make sure adding this HTLC won't violate any of the constraints we
// must keep on the commitment transactions.
remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote
Expand Down Expand Up @@ -6243,6 +6271,59 @@ func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor,
return err
}

// In order to avoid unnecessary validations of the aux bandwidth that
// may be costly to perform, let's skip unless this is the final check
// before adding the HTLC to the channel.
if !finalCheck {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of this boolean call it after validateHTLC in addHTLC ?

return nil
}

// If an auxiliary HTLC validator is configured, call it now to perform
// custom validation checks against the current channel state. This is
// the final validation point before the HTLC is added to the update
// log, ensuring that the validator sees the most up-to-date state
// including all previously validated HTLCs in this batch.
//
// NOTE: This is called after the standard commitment sanity checks to
// ensure we only perform (potentially) expensive custom validation on
// HTLCs that have already passed the basic Lightning protocol
// constraints.
err = fn.MapOptionZ(
lc.opts.auxHtlcValidator,
func(validator AuxHtlcValidator) error {
// Fetch the current HTLC view which includes all
// pending HTLCs that haven't been committed yet. This
// provides the validator with the most accurate state.
commitChain := lc.commitChains.Local
remoteIndex := commitChain.tail().messageIndices.Remote
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Thats the remoteACK why not call it like this ?

view := lc.fetchHTLCView(
remoteIndex,
lc.updateLogs.Local.logIndex,
)

nextHeight := lc.commitChains.Local.tip().height + 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not quite sure why we need this height ?

view.NextHeight = nextHeight

lc.log.Infof("Setting view nextheight=%v", nextHeight)

auxView := newAuxHtlcView(view)

// Get the current available balance for the link
// bandwidth check. This is needed for the reserve
// validation in the traffic shaper. We use NoBuffer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need an extra reserve check on the traffic shaper level ? It should be done by LND ?

// since this is the final check before adding the HTLC.
linkBandwidth, _ := lc.availableBalance(NoBuffer)

return validator(
pd.Amount, linkBandwidth, pd.CustomRecords,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this pd.Amount hold the sathosi value of the HTLC or the asset value ?

auxView,
)
},
)
if err != nil {
return fmt.Errorf("aux HTLC validation failed: %w", err)
}

return nil
}

Expand Down
86 changes: 86 additions & 0 deletions peer/brontide.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/lightningnetwork/lnd/pool"
"github.com/lightningnetwork/lnd/protofsm"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/subscribe"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
Expand Down Expand Up @@ -1139,6 +1140,16 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
},
)

p.cfg.AuxTrafficShaper.WhenSome(
func(ts htlcswitch.AuxTrafficShaper) {
val := p.createHtlcValidator(dbChan, ts)
chanOpts = append(
chanOpts,
lnwallet.WithAuxHtlcValidator(val),
)
},
)

lnChan, err := lnwallet.NewLightningChannel(
p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts...,
)
Expand Down Expand Up @@ -5228,6 +5239,15 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})

p.cfg.AuxTrafficShaper.WhenSome(
func(ts htlcswitch.AuxTrafficShaper) {
val := p.createHtlcValidator(c.OpenChannel, ts)
chanOpts = append(
chanOpts, lnwallet.WithAuxHtlcValidator(val),
)
},
)

// If not already active, we'll add this channel to the set of active
// channels, so we can look it up later easily according to its channel
// ID.
Expand Down Expand Up @@ -5434,6 +5454,72 @@ func (p *Brontide) scaleTimeout(timeout time.Duration) time.Duration {
return timeout
}

// createHtlcValidator creates an HTLC validator function that performs final
// aux balance validation before HTLCs are added to the channel state. This
// validator calls into the traffic shaper's PaymentBandwidth method to check
// external balance against the most up-to-date channel state, preventing race
// conditions where multiple HTLCs could be approved based on stale bandwidth.
func (p *Brontide) createHtlcValidator(dbChan *channeldb.OpenChannel,
ts htlcswitch.AuxTrafficShaper) lnwallet.AuxHtlcValidator {

return func(amount, linkBandwidth lnwire.MilliSatoshi,
customRecords lnwire.CustomRecords,
view lnwallet.AuxHtlcView) error {

// Get the short channel ID for logging.
scid := dbChan.ShortChannelID

// Extract the HTLC custom records to pass to the traffic
// shaper.
var htlcBlob fn.Option[tlv.Blob]
if len(customRecords) > 0 {
blob, err := customRecords.Serialize()
if err != nil {
return fmt.Errorf("unable to serialize "+
"custom records: %w", err)
}
htlcBlob = fn.Some(blob)
}

// Get the funding and commitment blobs for this channel.
fundingBlob := dbChan.CustomBlob
commitmentBlob := dbChan.LocalCommitment.CustomBlob

// Fetch the peer's public key.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: its not really fetching but converting to a proper format

peerBytes := p.IdentityKey().SerializeCompressed()
peer, err := route.NewVertexFromBytes(peerBytes)
if err != nil {
return fmt.Errorf("failed to create vertex from peer "+
"pub key: %w", err)
}
Comment on lines +5489 to +5494

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block can be simplified by using route.NewVertex. The current implementation serializes the public key to bytes, then route.NewVertexFromBytes parses it back to a public key for validation before converting it to a route.Vertex. Since p.IdentityKey() is guaranteed to return a valid *btcec.PublicKey, we can use route.NewVertex directly. This is slightly more efficient and makes the code cleaner by removing the unnecessary error handling.

Suggested change
peerBytes := p.IdentityKey().SerializeCompressed()
peer, err := route.NewVertexFromBytes(peerBytes)
if err != nil {
return fmt.Errorf("failed to create vertex from peer "+
"pub key: %w", err)
}
peer := route.NewVertex(p.IdentityKey())

// Call the traffic shaper's PaymentBandwidth method with the
// current state. This performs the same bandwidth checks as
// during pathfinding/forwarding, but against the absolute
// latest channel state.
//
// The linkBandwidth is provided by the channel and represents
// the current available balance, which is used by the traffic
// shaper to ensure we don't dip below channel reserves.
bandwidth, err := ts.PaymentBandwidth(
fundingBlob, htlcBlob, commitmentBlob,
linkBandwidth, amount, view, peer,
)
if err != nil {
return fmt.Errorf("traffic shaper bandwidth check "+
"failed: %w", err)
}

if amount > bandwidth {
return fmt.Errorf("insufficient aux bandwidth: "+
"need %v, have %v (scid=%v)", amount,
bandwidth, scid)
}

return nil
}
}

// CoopCloseUpdates is a struct used to communicate updates for an active close
// to the caller.
type CoopCloseUpdates struct {
Expand Down
Loading