@@ -574,3 +574,244 @@ public func generate(
         output: context.tokenizer.decode(tokens: tokens),
         promptTime: promptTime, generateTime: generateTime)
 }
+
+/// Generate tokens from an ``LMInput`` and a ``ModelContext``.
+///
+/// For example:
+///
+/// ```swift
+/// let generateParameters: GenerateParameters
+/// let input: UserInput
+/// let context: ModelContext
+///
+/// let lmInput = try context.processor.prepare(input: input)
+/// let result = generate(input: lmInput,
+///     parameters: generateParameters,
+///     context: context) { token in
+///     .more
+/// }
+/// ```
+///
+/// Internally this constructs a ``TokenIterator`` and calls
+/// ``generate(input:context:iterator:didGenerate:)``.
+///
+/// - Parameters:
+///   - input: prepared language model input
+///   - parameters: parameters controlling the token generation
+///   - context: model context (model and tokenizer)
+///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
+/// - Returns: Information about the generation
+public func generate(
+    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    didGenerate: (Int) -> GenerateDisposition
+) throws -> GenerateCompletionInfo {
+    let iterator = try TokenIterator(
+        input: input, model: context.model, parameters: parameters)
+    return generate(
+        input: input, context: context, iterator: iterator, didGenerate: didGenerate)
+}
+
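+/// Generate tokens from an ``LMInput`` using an existing ``TokenIterator``.
+///
+/// Tokens are passed to `didGenerate` as they are produced; iteration stops
+/// at an end-of-sequence token (including any `extraEOSTokens` from the model
+/// configuration) or when the visitor returns `.stop`.
+///
+/// For example, a sketch mirroring the convenience overload above (it assumes
+/// `input`, `parameters`, and `context` are already prepared):
+///
+/// ```swift
+/// let iterator = try TokenIterator(
+///     input: input, model: context.model, parameters: parameters)
+/// let result = generate(
+///     input: input, context: context, iterator: iterator) { token in
+///     .more
+/// }
+/// ```
+///
+/// - Parameters:
+///   - input: prepared language model input
+///   - context: model context (model and tokenizer)
+///   - iterator: the ``TokenIterator`` that produces the raw tokens
+///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
+/// - Returns: Information about the generation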
+public func generate(
+    input: LMInput, context: ModelContext,
+    iterator: TokenIterator,
+    didGenerate: (Int) -> GenerateDisposition
+) -> GenerateCompletionInfo {
+    var start = Date.timeIntervalSinceReferenceDate
+    var promptTime: TimeInterval = 0
+
+    let additionalEOSTokenIds = Set(
+        (context.configuration.extraEOSTokens ?? [])
+            .compactMap {
+                context.tokenizer.convertTokenToId($0)
+            })
+
+    var tokenCount = 0
+
+    for token in iterator {
+        // Compute the timing for the prompt
+        if promptTime == 0 {
+            let now = Date.timeIntervalSinceReferenceDate
+            promptTime = now - start
+            start = now
+        }
+
+        // Check for end-of-sequence tokens
+        if token == context.tokenizer.unknownTokenId || token == context.tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(token)
+        {
+            break
+        }
+
+        tokenCount += 1
+
+        // Invoke the callback with the current token
+        if didGenerate(token) == .stop {
+            break
+        }
+    }
+
+    let now = Date.timeIntervalSinceReferenceDate
+    let generateTime = now - start
+
+    // Synchronize with the stream to ensure tasks are completed
+    Stream().synchronize()
+
+    return GenerateCompletionInfo(
+        promptTokenCount: input.text.tokens.size,
+        generationTokenCount: tokenCount,
+        promptTime: promptTime,
+        generationTime: generateTime
+    )
+}
+
+/// Generates tokens asynchronously using the provided language model input, parameters, and context.
+///
+/// This function initializes a `TokenIterator` with the given input, model, and generation parameters,
+/// and then streams the token generation process via an `AsyncStream`. The resulting stream yields
+/// instances of the `Generation` enum, which can represent either individual tokens or summary
+/// completion information.
+///
+/// - Parameters:
+///   - input: The input for the language model.
+///   - parameters: The configuration options for token generation.
+///   - context: The model context, including the model itself and associated tokenizer.
+/// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
+///   and completion information (`.info`).
+/// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
+///
+/// ### Example Usage:
+/// ```swift
+/// // Define the input, parameters, and context for token generation.
+/// let generateParameters: GenerateParameters
+/// let input: UserInput
+/// let context: ModelContext
+///
+/// let lmInput = try context.processor.prepare(input: input)
+///
+/// // Call the generate function to get an AsyncStream.
+/// let stream = try generate(input: lmInput, parameters: generateParameters, context: context)
+///
+/// // Process the stream asynchronously to handle generated tokens and completion info.
+/// for await generation in stream {
+///     switch generation {
+///     case .token(let token):
+///         print("Generated token: \(context.tokenizer.decode(tokens: [token]))")
+///     case .info(let info):
+///         print("Finished: \(info.tokensPerSecond) tokens/s.")
+///     }
+/// }
+/// ```
+public func generate(
+    input: LMInput, parameters: GenerateParameters, context: ModelContext
+) throws -> AsyncStream<Generation> {
+    let iterator = try TokenIterator(
+        input: input, model: context.model, parameters: parameters)
+    return generate(
+        input: input, context: context, iterator: iterator)
+}
+
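+/// Generates tokens asynchronously from an existing ``TokenIterator``, yielding
+/// a `.token` value for each generated token followed by a final `.info` value
+/// with the completion statistics.
+///
+/// The iteration runs in a `Task` that is cancelled when the consumer
+/// terminates the stream. For example, a sketch (it assumes `input`, `context`,
+/// and `iterator` are already prepared):
+///
+/// ```swift
+/// for await generation in generate(input: input, context: context, iterator: iterator) {
+///     if case .token(let token) = generation {
+///         print(token)
+///     }
+/// }
+/// ```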
+public func generate(
+    input: LMInput, context: ModelContext,
+    iterator: TokenIterator
+) -> AsyncStream<Generation> {
+
+    AsyncStream { continuation in
+
+        // Launch a Task to perform iteration asynchronously.
+        let task = Task {
+            var start = Date.timeIntervalSinceReferenceDate
+            var promptTime: TimeInterval = 0
+
+            let additionalEOSTokenIds = Set(
+                (context.configuration.extraEOSTokens ?? [])
+                    .compactMap {
+                        context.tokenizer.convertTokenToId($0)
+                    })
+
+            var tokenCount = 0
+
+            for token in iterator {
+
+                // Check for cancellation on every loop iteration.
+                if Task.isCancelled { break }
+
+                if promptTime == 0 {
+                    let now = Date.timeIntervalSinceReferenceDate
+                    promptTime = now - start
+                    start = now
+                }
+
+                if token == context.tokenizer.unknownTokenId
+                    || token == context.tokenizer.eosTokenId
+                    || additionalEOSTokenIds.contains(token)
+                {
+                    break
+                }
+
+                tokenCount += 1
+                continuation.yield(.token(token))
+            }
+
+            let now = Date.timeIntervalSinceReferenceDate
+            let generateTime = now - start
+
+            let info = GenerateCompletionInfo(
+                promptTokenCount: input.text.tokens.size,
+                generationTokenCount: tokenCount,
+                promptTime: promptTime,
+                generationTime: generateTime
+            )
+            continuation.yield(.info(info))
+
+            // Synchronize with the stream to ensure tasks are completed
+            Stream().synchronize()
+
+            // Finalize the stream
+            continuation.finish()
+        }
+        // When the consumer cancels (or ends) the stream,
+        // cancel our underlying task.
+        continuation.onTermination = { _ in
+            task.cancel()
+        }
+    }
+}
+
+/// Represents metadata and statistics related to token generation.
+///
+/// Provides information about the number of tokens processed during both the prompt and generation phases, as well as the time taken for each phase.
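+///
+/// For example, a prompt of 64 tokens processed in 0.5 seconds yields a
+/// `promptTokensPerSecond` of 128 (these counts are purely illustrative).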
+public struct GenerateCompletionInfo {
+    /// The number of tokens included in the input prompt.
+    let promptTokenCount: Int
+
+    /// The number of tokens generated by the language model.
+    let generationTokenCount: Int
+
+    /// The time interval (in seconds) taken to process the input prompt.
+    let promptTime: TimeInterval
+
+    /// The time interval (in seconds) taken to generate the output tokens.
+    let generationTime: TimeInterval
+
+    /// The number of tokens processed per second during the prompt phase.
+    public var promptTokensPerSecond: Double {
+        Double(promptTokenCount) / promptTime
+    }
+
+    /// The number of tokens generated per second during the generation phase.
+    public var tokensPerSecond: Double {
+        Double(generationTokenCount) / generationTime
+    }
+}
+
+/// Represents the different stages or outputs of the token generation process.
+///
+/// This enum distinguishes between the following:
+/// - `.token`: An individual token generated by the language model.
+/// - `.info`: Metadata and performance statistics about the generation process.
+public enum Generation {
+    /// A generated token represented as an integer.
+    case token(Int)
+
+    /// Completion information summarizing token counts and performance metrics.
+    case info(GenerateCompletionInfo)
+}