Skip to content

Commit f6830fa

Browse files
committed
fix: move validateToRawMessage() out of validateReceivedMessage()
1 parent 95da5e3 commit f6830fa

File tree

2 files changed

+109
-115
lines changed

2 files changed

+109
-115
lines changed

src/index.ts

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ import {
5454
type MessageId,
5555
type PublishOpts
5656
} from './types.js'
57-
import { buildRawMessage, validateToRawMessage } from './utils/buildRawMessage.js'
57+
import { buildRawMessage, validateStrictNoSignMessage, validateStrictSignMessage, type ValidationResult } from './utils/buildRawMessage.js'
5858
import { createGossipRpc, ensureControl } from './utils/create-gossip-rpc.js'
5959
import { shuffle, messageIdToString } from './utils/index.js'
6060
import { msgIdFnStrictNoSign, msgIdFnStrictSign } from './utils/msgIdFn.js'
@@ -1214,28 +1214,40 @@ export class GossipSub extends TypedEventEmitter<GossipsubEvents> implements Pub
12141214
private async handleReceivedMessage (from: PeerId, rpcMsg: RPC.Message): Promise<void> {
12151215
this.metrics?.onMsgRecvPreValidation(rpcMsg.topic)
12161216

1217-
let validationResult = await this.validateReceivedMessage(from, rpcMsg)
1218-
1219-
if (validationResult.code === MessageStatus.valid) {
1220-
// (Optional) Provide custom validation here with dynamic validators per topic
1221-
// NOTE: This custom topicValidator() must resolve fast (< 100ms) to allow scores
1222-
// to not penalize peers for long validation times.
1223-
const msgIdStr = validationResult.messageId.msgIdStr
1224-
const topicValidator = this.topicValidators.get(rpcMsg.topic)
1225-
if (topicValidator != null) {
1226-
let acceptance: TopicValidatorResult
1227-
// Use try {} catch {} in case topicValidator() is synchronous
1228-
try {
1229-
acceptance = await topicValidator(from, validationResult.msg)
1230-
} catch (e) {
1231-
const errCode = (e as { code: string }).code
1232-
if (errCode === constants.ERR_TOPIC_VALIDATOR_IGNORE) acceptance = TopicValidatorResult.Ignore
1233-
if (errCode === constants.ERR_TOPIC_VALIDATOR_REJECT) acceptance = TopicValidatorResult.Reject
1234-
else acceptance = TopicValidatorResult.Ignore
1235-
}
1217+
let validationResult: ReceivedMessageResult
1218+
// Fast message ID stuff
1219+
const fastMsgIdStr = this.fastMsgIdFn?.(rpcMsg)
1220+
const msgIdCached = fastMsgIdStr !== undefined ? this.fastMsgIdCache?.get(fastMsgIdStr) : undefined
12361221

1237-
if (acceptance !== TopicValidatorResult.Accept) {
1238-
validationResult = { code: MessageStatus.invalid, reason: rejectReasonFromAcceptance(acceptance), msgIdStr }
1222+
if (msgIdCached != null) {
1223+
// This message has been seen previously. Ignore it
1224+
validationResult = { code: MessageStatus.duplicate, msgIdStr: msgIdCached }
1225+
} else {
1226+
const rawValidationResult = this.globalSignaturePolicy === StrictNoSign ? validateStrictNoSignMessage(rpcMsg) : await validateStrictSignMessage(rpcMsg)
1227+
// Perform basic validation on message and convert to RawGossipsubMessage for fastMsgIdFn()
1228+
validationResult = this.validateReceivedMessage(from, rpcMsg, fastMsgIdStr, rawValidationResult)
1229+
1230+
if (validationResult.code === MessageStatus.valid) {
1231+
// (Optional) Provide custom validation here with dynamic validators per topic
1232+
// NOTE: This custom topicValidator() must resolve fast (< 100ms) to allow scores
1233+
// to not penalize peers for long validation times.
1234+
const msgIdStr = validationResult.messageId.msgIdStr
1235+
const topicValidator = this.topicValidators.get(rpcMsg.topic)
1236+
if (topicValidator != null) {
1237+
let acceptance: TopicValidatorResult
1238+
// Use try {} catch {} in case topicValidator() is synchronous
1239+
try {
1240+
acceptance = await topicValidator(from, validationResult.msg)
1241+
} catch (e) {
1242+
const errCode = (e as { code: string }).code
1243+
if (errCode === constants.ERR_TOPIC_VALIDATOR_IGNORE) acceptance = TopicValidatorResult.Ignore
1244+
if (errCode === constants.ERR_TOPIC_VALIDATOR_REJECT) acceptance = TopicValidatorResult.Reject
1245+
else acceptance = TopicValidatorResult.Ignore
1246+
}
1247+
1248+
if (acceptance !== TopicValidatorResult.Accept) {
1249+
validationResult = { code: MessageStatus.invalid, reason: rejectReasonFromAcceptance(acceptance), msgIdStr }
1250+
}
12391251
}
12401252
}
12411253
}
@@ -1316,22 +1328,12 @@ export class GossipSub extends TypedEventEmitter<GossipsubEvents> implements Pub
13161328
* Handles a newly received message from an RPC.
13171329
* May forward to all peers in the mesh.
13181330
*/
1319-
private async validateReceivedMessage (
1331+
private validateReceivedMessage (
13201332
propagationSource: PeerId,
1321-
rpcMsg: RPC.Message
1322-
): Promise<ReceivedMessageResult> {
1323-
// Fast message ID stuff
1324-
const fastMsgIdStr = this.fastMsgIdFn?.(rpcMsg)
1325-
const msgIdCached = fastMsgIdStr !== undefined ? this.fastMsgIdCache?.get(fastMsgIdStr) : undefined
1326-
1327-
if (msgIdCached != null) {
1328-
// This message has been seen previously. Ignore it
1329-
return { code: MessageStatus.duplicate, msgIdStr: msgIdCached }
1330-
}
1331-
1332-
// Perform basic validation on message and convert to RawGossipsubMessage for fastMsgIdFn()
1333-
const validationResult = await validateToRawMessage(this.globalSignaturePolicy, rpcMsg)
1334-
1333+
rpcMsg: RPC.Message,
1334+
fastMsgIdStr: string | number | undefined,
1335+
validationResult: ValidationResult
1336+
): ReceivedMessageResult {
13351337
if (!validationResult.valid) {
13361338
return { code: MessageStatus.invalid, reason: RejectReason.Error, error: validationResult.error }
13371339
}

src/utils/buildRawMessage.ts

Lines changed: 70 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { randomBytes } from '@libp2p/crypto'
22
import { publicKeyFromProtobuf } from '@libp2p/crypto/keys'
3-
import { StrictSign, StrictNoSign, type Message, type PublicKey, type PeerId } from '@libp2p/interface'
3+
import { type Message, type PublicKey, type PeerId } from '@libp2p/interface'
44
import { peerIdFromMultihash } from '@libp2p/peer-id'
55
import * as Digest from 'multiformats/hashes/digest'
66
import { concat as uint8ArrayConcat } from 'uint8arrays/concat'
@@ -80,92 +80,84 @@ export async function buildRawMessage (
8080

8181
export type ValidationResult = { valid: true, message: Message } | { valid: false, error: ValidateError }
8282

83-
export async function validateToRawMessage (
84-
signaturePolicy: typeof StrictNoSign | typeof StrictSign,
83+
export function validateStrictNoSignMessage (
84+
msg: RPC.Message
85+
): ValidationResult {
86+
if (msg.signature != null) return { valid: false, error: ValidateError.SignaturePresent }
87+
if (msg.seqno != null) return { valid: false, error: ValidateError.SeqnoPresent }
88+
if (msg.from != null) return { valid: false, error: ValidateError.FromPresent }
89+
90+
return { valid: true, message: { type: 'unsigned', topic: msg.topic, data: msg.data ?? new Uint8Array(0) } }
91+
}
92+
93+
export async function validateStrictSignMessage (
8594
msg: RPC.Message
8695
): Promise<ValidationResult> {
87-
// If strict-sign, verify all
88-
// If anonymous (no-sign), ensure no preven
89-
90-
switch (signaturePolicy) {
91-
case StrictNoSign:
92-
if (msg.signature != null) return { valid: false, error: ValidateError.SignaturePresent }
93-
if (msg.seqno != null) return { valid: false, error: ValidateError.SeqnoPresent }
94-
if (msg.from != null) return { valid: false, error: ValidateError.FromPresent }
95-
96-
return { valid: true, message: { type: 'unsigned', topic: msg.topic, data: msg.data ?? new Uint8Array(0) } }
97-
98-
case StrictSign: {
99-
// Verify seqno
100-
if (msg.seqno == null) return { valid: false, error: ValidateError.InvalidSeqno }
101-
if (msg.seqno.length !== 8) {
102-
return { valid: false, error: ValidateError.InvalidSeqno }
103-
}
96+
// Verify seqno
97+
if (msg.seqno == null) return { valid: false, error: ValidateError.InvalidSeqno }
98+
if (msg.seqno.length !== 8) {
99+
return { valid: false, error: ValidateError.InvalidSeqno }
100+
}
104101

105-
if (msg.signature == null) return { valid: false, error: ValidateError.InvalidSignature }
106-
if (msg.from == null) return { valid: false, error: ValidateError.InvalidPeerId }
102+
if (msg.signature == null) return { valid: false, error: ValidateError.InvalidSignature }
103+
if (msg.from == null) return { valid: false, error: ValidateError.InvalidPeerId }
107104

108-
let fromPeerId: PeerId
109-
try {
110-
// TODO: Fix PeerId types
111-
fromPeerId = peerIdFromMultihash(Digest.decode(msg.from))
112-
} catch (e) {
113-
return { valid: false, error: ValidateError.InvalidPeerId }
114-
}
105+
let fromPeerId: PeerId
106+
try {
107+
// TODO: Fix PeerId types
108+
fromPeerId = peerIdFromMultihash(Digest.decode(msg.from))
109+
} catch (e) {
110+
return { valid: false, error: ValidateError.InvalidPeerId }
111+
}
115112

116-
// - check from defined
117-
// - transform source to PeerId
118-
// - parse signature
119-
// - get .key, else from source
120-
// - check key == source if present
121-
// - verify sig
122-
123-
let publicKey: PublicKey
124-
if (msg.key != null) {
125-
publicKey = publicKeyFromProtobuf(msg.key)
126-
// TODO: Should `fromPeerId.pubKey` be optional?
127-
if (fromPeerId.publicKey !== undefined && !publicKey.equals(fromPeerId.publicKey)) {
128-
return { valid: false, error: ValidateError.InvalidPeerId }
129-
}
130-
} else {
131-
if (fromPeerId.publicKey == null) {
132-
return { valid: false, error: ValidateError.InvalidPeerId }
133-
}
134-
publicKey = fromPeerId.publicKey
135-
}
113+
// - check from defined
114+
// - transform source to PeerId
115+
// - parse signature
116+
// - get .key, else from source
117+
// - check key == source if present
118+
// - verify sig
119+
120+
let publicKey: PublicKey
121+
if (msg.key != null) {
122+
publicKey = publicKeyFromProtobuf(msg.key)
123+
// TODO: Should `fromPeerId.pubKey` be optional?
124+
if (fromPeerId.publicKey !== undefined && !publicKey.equals(fromPeerId.publicKey)) {
125+
return { valid: false, error: ValidateError.InvalidPeerId }
126+
}
127+
} else {
128+
if (fromPeerId.publicKey == null) {
129+
return { valid: false, error: ValidateError.InvalidPeerId }
130+
}
131+
publicKey = fromPeerId.publicKey
132+
}
136133

137-
const rpcMsgPreSign: RPC.Message = {
138-
from: msg.from,
139-
data: msg.data,
140-
seqno: msg.seqno,
141-
topic: msg.topic,
142-
signature: undefined, // Exclude signature field for signing
143-
key: undefined // Exclude key field for signing
144-
}
134+
const rpcMsgPreSign: RPC.Message = {
135+
from: msg.from,
136+
data: msg.data,
137+
seqno: msg.seqno,
138+
topic: msg.topic,
139+
signature: undefined, // Exclude signature field for signing
140+
key: undefined // Exclude key field for signing
141+
}
145142

146-
// Get the message in bytes, and prepend with the pubsub prefix
147-
// the signature is over the bytes "libp2p-pubsub:<protobuf-message>"
148-
const bytes = uint8ArrayConcat([SignPrefix, RPC.Message.encode(rpcMsgPreSign)])
143+
// Get the message in bytes, and prepend with the pubsub prefix
144+
// the signature is over the bytes "libp2p-pubsub:<protobuf-message>"
145+
const bytes = uint8ArrayConcat([SignPrefix, RPC.Message.encode(rpcMsgPreSign)])
149146

150-
if (!(await publicKey.verify(bytes, msg.signature))) {
151-
return { valid: false, error: ValidateError.InvalidSignature }
152-
}
147+
if (!(await publicKey.verify(bytes, msg.signature))) {
148+
return { valid: false, error: ValidateError.InvalidSignature }
149+
}
153150

154-
return {
155-
valid: true,
156-
message: {
157-
type: 'signed',
158-
from: fromPeerId,
159-
data: msg.data ?? new Uint8Array(0),
160-
sequenceNumber: BigInt(`0x${uint8ArrayToString(msg.seqno, 'base16')}`),
161-
topic: msg.topic,
162-
signature: msg.signature,
163-
key: publicKey
164-
}
165-
}
151+
return {
152+
valid: true,
153+
message: {
154+
type: 'signed',
155+
from: fromPeerId,
156+
data: msg.data ?? new Uint8Array(0),
157+
sequenceNumber: BigInt(`0x${uint8ArrayToString(msg.seqno, 'base16')}`),
158+
topic: msg.topic,
159+
signature: msg.signature,
160+
key: publicKey
166161
}
167-
168-
default:
169-
throw new Error('Unreachable')
170162
}
171163
}

0 commit comments

Comments
 (0)