
Commit 5029e5b

Improved token generation: single-token version, AsyncStream functionality, and supporting structures (#248)
* Make the initializer public for GenerateResult and the sampler and processor methods of GenerateParameters.
* fix format
* Implement Equatable for and .
* feat: Add single-token and AsyncStream token generation, with supporting structures
* fix: Task cancellation
1 parent b934138 commit 5029e5b

File tree

1 file changed: +241 -0 lines changed

Libraries/MLXLMCommon/Evaluate.swift

@@ -574,3 +574,244 @@ public func generate(
        output: context.tokenizer.decode(tokens: tokens),
        promptTime: promptTime, generateTime: generateTime)
}

/// Generate tokens from an ``LMInput`` and a ``ModelContext``.
///
/// For example:
///
/// ```swift
/// let generateParameters: GenerateParameters
/// let input: UserInput
/// let context: ModelContext
///
/// let lmInput = try context.processor.prepare(input: input)
/// let result = generate(input: lmInput,
///     parameters: generateParameters,
///     context: context) { token in
///     .more
/// }
/// ```
///
/// Internally this constructs a ``TokenIterator`` and calls
/// ``generate(input:context:iterator:didGenerate:)``
///
/// - Parameters:
///   - input: prepared language model input
///   - parameters: parameters controlling the token generation
///   - context: model context (model and tokenizer)
///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
/// - Returns: Information about the generation
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: (Int) -> GenerateDisposition
) throws -> GenerateCompletionInfo {
    let iterator = try TokenIterator(
        input: input, model: context.model, parameters: parameters)
    return generate(
        input: input, context: context, iterator: iterator, didGenerate: didGenerate)
}
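
A minimal usage sketch for this overload with an early stop, assuming `context` and `lmInput` are prepared as in the doc comment above and that `GenerateParameters` offers a default initializer:

    // Sketch: collect up to 100 tokens, then stop early.
    var tokens = [Int]()
    let info = try generate(
        input: lmInput, parameters: GenerateParameters(), context: context
    ) { token in
        tokens.append(token)
        return tokens.count < 100 ? .more : .stop
    }
    print(context.tokenizer.decode(tokens: tokens))
    print("\(info.tokensPerSecond) tokens/s")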

public func generate(
    input: LMInput, context: ModelContext,
    iterator: TokenIterator,
    didGenerate: (Int) -> GenerateDisposition
) -> GenerateCompletionInfo {
    var start = Date.timeIntervalSinceReferenceDate
    var promptTime: TimeInterval = 0

    let additionalEOSTokenIds = Set(
        (context.configuration.extraEOSTokens ?? [])
            .compactMap {
                context.tokenizer.convertTokenToId($0)
            })

    var tokenCount = 0

    for token in iterator {
        // Compute the timing for the prompt
        if promptTime == 0 {
            let now = Date.timeIntervalSinceReferenceDate
            promptTime = now - start
            start = now
        }

        // Check for end-of-sequence tokens
        if token == context.tokenizer.unknownTokenId || token == context.tokenizer.eosTokenId
            || additionalEOSTokenIds.contains(token)
        {
            break
        }

        tokenCount += 1

        // Invoke the callback with the current token
        if didGenerate(token) == .stop {
            break
        }
    }

    let now = Date.timeIntervalSinceReferenceDate
    let generateTime = now - start

    // Synchronize with the stream to ensure tasks are completed
    Stream().synchronize()

    return GenerateCompletionInfo(
        promptTokenCount: input.text.tokens.size,
        generationTokenCount: tokenCount,
        promptTime: promptTime,
        generationTime: generateTime
    )
}
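
This overload accepts a caller-constructed ``TokenIterator``, which is useful when the iterator is configured separately from the generation call. A minimal sketch, assuming the same `lmInput` and `context` as above:

    // Sketch: drive the non-throwing overload with an explicit iterator.
    let iterator = try TokenIterator(
        input: lmInput, model: context.model, parameters: GenerateParameters())
    let info = generate(input: lmInput, context: context, iterator: iterator) { token in
        // Stream decoded text as it is produced.
        print(context.tokenizer.decode(tokens: [token]), terminator: "")
        return .more
    }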

/// Generates tokens asynchronously using the provided language model input, parameters, and context.
///
/// This function initializes a `TokenIterator` with the given input, model, and generation parameters,
/// and then streams the token generation process via an `AsyncStream`. The resulting stream yields
/// instances of the `Generation` enum, which can represent either individual tokens or summary
/// completion information.
///
/// - Parameters:
///   - input: The input for the language model.
///   - parameters: The configuration options for token generation.
///   - context: The model context, including the model itself and associated tokenizer.
/// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
///   and completion information (`.info`).
/// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
///
/// ### Example Usage:
/// ```swift
/// // Define the input, parameters, and context for token generation.
/// let generateParameters: GenerateParameters
/// let input: UserInput
/// let context: ModelContext
///
/// let lmInput = try context.processor.prepare(input: input)
///
/// // Call the generate function to get an AsyncStream.
/// let stream = try generate(input: lmInput, parameters: generateParameters, context: context)
///
/// // Process the stream asynchronously to handle generated tokens and completion info.
/// for await generation in stream {
///     switch generation {
///     case .token(let token):
///         print("Generated token: \(context.tokenizer.decode(tokens: [token]))")
///     case .info(let info):
///         print("Finished: \(info.tokensPerSecond) tokens/s.")
///     }
/// }
/// ```
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext
) throws -> AsyncStream<Generation> {
    let iterator = try TokenIterator(
        input: input, model: context.model, parameters: parameters)
    return generate(
        input: input, context: context, iterator: iterator)
}

public func generate(
    input: LMInput, context: ModelContext,
    iterator: TokenIterator
) -> AsyncStream<Generation> {

    AsyncStream { continuation in

        // Launch a Task to perform iteration asynchronously.
        let task = Task {
            var start = Date.timeIntervalSinceReferenceDate
            var promptTime: TimeInterval = 0

            let additionalEOSTokenIds = Set(
                (context.configuration.extraEOSTokens ?? [])
                    .compactMap {
                        context.tokenizer.convertTokenToId($0)
                    })

            var tokenCount = 0

            for token in iterator {

                // Check for cancellation on every loop iteration.
                if Task.isCancelled { break }

                if promptTime == 0 {
                    let now = Date.timeIntervalSinceReferenceDate
                    promptTime = now - start
                    start = now
                }

                if token == context.tokenizer.unknownTokenId
                    || token == context.tokenizer.eosTokenId
                    || additionalEOSTokenIds.contains(token)
                {
                    break
                }

                tokenCount += 1
                continuation.yield(.token(token))
            }

            let now = Date.timeIntervalSinceReferenceDate
            let generateTime = now - start

            let info = GenerateCompletionInfo(
                promptTokenCount: input.text.tokens.size,
                generationTokenCount: tokenCount,
                promptTime: promptTime,
                generationTime: generateTime
            )
            continuation.yield(.info(info))

            // Synchronize with the stream to ensure tasks are completed
            Stream().synchronize()

            // Finalize the stream
            continuation.finish()
        }

        // When the consumer cancels (or ends) the stream,
        // cancel our underlying task.
        continuation.onTermination = { _ in
            task.cancel()
        }
    }
}
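
Because `continuation.onTermination` cancels the producer task and the loop checks `Task.isCancelled` each iteration, a consumer can stop generation by cancelling the task that iterates the stream. A minimal sketch, assuming `lmInput` and `context` as in the examples above:

    // Sketch: consume the stream inside a Task and cancel from outside.
    let consumer = Task {
        for await generation in try generate(
            input: lmInput, parameters: GenerateParameters(), context: context)
        {
            if case .token(let token) = generation {
                print(context.tokenizer.decode(tokens: [token]), terminator: "")
            }
        }
    }
    // Later, e.g. from a "stop" button: this ends iteration and, via
    // onTermination, cancels the underlying generation task.
    consumer.cancel()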

/// Represents metadata and statistics related to token generation.
///
/// Provides information about the number of tokens processed during both the prompt and generation phases, as well as the time taken for each phase.
public struct GenerateCompletionInfo {
    /// The number of tokens included in the input prompt.
    let promptTokenCount: Int

    /// The number of tokens generated by the language model.
    let generationTokenCount: Int

    /// The time interval (in seconds) taken to process the input prompt.
    let promptTime: TimeInterval

    /// The time interval (in seconds) taken to generate the output tokens.
    let generationTime: TimeInterval

    /// The number of tokens processed per second during the prompt phase.
    public var promptTokensPerSecond: Double {
        Double(promptTokenCount) / promptTime
    }

    /// The number of tokens generated per second during the generation phase.
    public var tokensPerSecond: Double {
        Double(generationTokenCount) / generationTime
    }
}
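
Note that the stored counts and times are not declared public, so outside the module the public computed rates are what callers consume. For example, given an `info` returned by one of the overloads above:

    // Sketch: report throughput from a completed generation.
    print("Prompt: \(info.promptTokensPerSecond) tokens/s")
    print("Generation: \(info.tokensPerSecond) tokens/s")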

/// Represents the different stages or outputs of the token generation process.
///
/// This enum distinguishes between the following:
/// - `.token`: An individual token generated by the language model.
/// - `.info`: Metadata and performance statistics about the generation process.
public enum Generation {
    /// A generated token represented as an integer.
    case token(Int)
    /// Completion information summarizing token counts and performance metrics.
    case info(GenerateCompletionInfo)
}
