Skip to content

Commit

Permalink
WIP: Use external transcription-service for transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
philmcmahon committed Mar 12, 2024
1 parent efcd5ae commit 46377d9
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 21 deletions.
11 changes: 9 additions & 2 deletions backend/app/AppComponents.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import org.apache.pekko.actor.{ActorSystem, CoordinatedShutdown}
import org.apache.pekko.actor.CoordinatedShutdown.Reason
import cats.syntax.either._
import com.amazonaws.services.sqs.{AmazonSQSClient, AmazonSQSClientBuilder}
import com.gu.pandomainauth
import com.gu.pandomainauth.PublicSettings
import controllers.AssetsComponents
Expand All @@ -15,7 +16,7 @@ import extraction.email.olm.OlmEmailExtractor
import extraction.email.pst.PstEmailExtractor
import extraction.ocr.{ImageOcrExtractor, OcrMyPdfExtractor, OcrMyPdfImageExtractor, TesseractPdfOcrExtractor}
import extraction.tables.{CsvTableExtractor, ExcelTableExtractor}
import extraction.{DocumentBodyExtractor, MimeTypeMapper, TranscriptionExtractor, Worker}
import extraction.{DocumentBodyExtractor, ExternalTranscriptionExtractor, MimeTypeMapper, TranscriptionExtractor, Worker}
import ingestion.phase2.IngestStorePolling
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.neo4j.driver.v1.{AuthTokens, GraphDatabase}
Expand Down Expand Up @@ -74,6 +75,7 @@ class AppComponents(context: Context, config: Config)
val ingestionExecutionContext = actorSystem.dispatchers.lookup("ingestion-context")

val s3Client = new S3Client(config.s3)(s3ExecutionContext)
val sqsClient = AmazonSQSClientBuilder.standard().withRegion(config.sqs.region).build()

val workerName = config.worker.name.getOrElse(InetAddress.getLocalHost.getHostName)

Expand Down Expand Up @@ -150,7 +152,12 @@ class AppComponents(context: Context, config: Config)
val imageOcrExtractor = new ImageOcrExtractor(config.ocr, scratchSpace, esResources, ingestionServices)
val ocrMyPdfImageExtractor = new OcrMyPdfImageExtractor(config.ocr, scratchSpace, esResources, previewStorage, ingestionServices)

val transcriptionExtractor = new TranscriptionExtractor(esResources, scratchSpace, config.transcribe)

val transcriptionExtractor = if (config.worker.useExternalExtractors) {
new ExternalTranscriptionExtractor(esResources, config.transcribe, blobStorage, sqsClient)
} else {
new TranscriptionExtractor(esResources, scratchSpace, config.transcribe)
}

val ocrExtractors = config.ocr.defaultEngine match {
case OcrEngine.OcrMyPdf => List(ocrMyPdfExtractor, ocrMyPdfImageExtractor)
Expand Down
26 changes: 26 additions & 0 deletions backend/app/extraction/ExternalExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package extraction

import model.manifest.Blob
import utils.attempt.Failure

import java.io.InputStream

/**
* External Extractors are where the actual extraction doesn't take place on the worker but in some third party service
* The behaviour is a little different as we need to trigger the extraction, then the worker can get on with other tasks
* whilst waiting for a response from the third party service. Once the response comes in we need to store the data
* and update the manifest to mark the extraction as complete
*/
abstract class ExternalExtractor extends Extractor {

override def external = true

final override def extract(blob: Blob, inputStream: InputStream, params: ExtractionParams): Either[Failure, Unit] = {
triggerExtraction(blob, params)
}

def triggerExtraction(blob: Blob, params: ExtractionParams): Either[Failure, Unit]

def pollForResults(): Either[Failure, Unit]

}
124 changes: 124 additions & 0 deletions backend/app/extraction/ExternalTranscriptionExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package extraction

import cats.syntax.either._
import com.amazonaws.services.sqs.{AmazonSQS, AmazonSQSClient}
import com.amazonaws.services.sqs.model.ReceiveMessageRequest
import model.manifest.Blob
import model.{English, Languages, Uri}
import org.apache.commons.io.FileUtils
import org.joda.time.DateTime
import play.api.libs.json.Json
import services.{ObjectStorage, ScratchSpace, TranscribeConfig, WorkerConfig}
import services.index.Index
import utils.FfMpeg.FfMpegSubprocessCrashedException
import utils.attempt.{Failure, FfMpegFailure, UnknownFailure}
import utils._

import scala.concurrent.ExecutionContext
import java.io.File
import java.nio.charset.StandardCharsets
import java.util.UUID


case class OutputBucketUrls(txt: String, srt: String, json: String)
case class TranscriptionJob(id: String, originalFilename: String, inputSignedUrl: String, sentTimestamp: String,
userEmail: String, transcriptDestinationService: String, outputBucketUrls: OutputBucketUrls,
languageCode: String)
object OutputBucketUrls {
implicit val formats = Json.format[OutputBucketUrls]

}
object TranscriptionJob {
implicit val formats = Json.format[TranscriptionJob]
}

/**
* id: z.string(),
* originalFilename: z.string(),
* userEmail: z.string(),
* status: z.literal('SUCCESS'),
* languageCode: z.string(),
* outputBucketKeys: OutputBucketKeys,
*/

case class TranscriptionOutput(id: String, originalFilename: String, userEmail: String, status: String, languageCode: String, outputBucketUrls: OutputBucketUrls)
object TranscriptionOutput {
implicit val formats = Json.format[TranscriptionOutput]
}

class ExternalTranscriptionExtractor(index: Index, transcribeConfig: TranscribeConfig, blobStorage: ObjectStorage, amazonSQSClient: AmazonSQS)(implicit executionContext: ExecutionContext) extends ExternalExtractor with Logging {
val mimeTypes: Set[String] = Set(
"audio/wav",
"audio/vnd.wave",
"audio/x-aiff", // converted and transcribed. But preview doesn't work
"audio/mpeg",
"audio/aac", // tika can't detect this!!
"audio/vorbis", // Converted by ffmpeg but failed in whisper
"audio/opus",
"audio/amr", // converted and transcribed. But preview doesn't work
"audio/amr-wb", // Couldn't find a sample to test
"audio/x-caf", // Couldn't find a sample to test
"audio/mp4", // Couldn't find a sample to test
"audio/x-ms-wma", // converted and transcribed. But preview doesn't work
"video/3gpp",
"video/mp4", // quicktime detected for some of mp4 samples
"video/quicktime",
"video/x-flv", // converted and transcribed. But preview doesn't work
"video/x-ms-wmv", // converted and transcribed. But preview doesn't work
"video/x-msvideo", // converted and transcribed. But preview doesn't work
"video/x-m4v",
"video/mpeg" // converted and transcribed. But preview doesn't work
)

def canProcessMimeType: String => Boolean = mimeTypes.contains

override def indexing = true
// set a low priority as transcription takes a long time, we don't want to block up the workers
override def priority = 2

private val dataBucketPrefix = "transcription-service-output-data"
private def getOutputBucketUrls(blobUri: String): OutputBucketUrls = {
val txt = s"$dataBucketPrefix/$blobUri.txt"
// we should find a way to avoid having to provide these
val srt = s"$dataBucketPrefix/$blobUri.srt"
val json = s"$dataBucketPrefix/$blobUri.json"
OutputBucketUrls(txt, srt, json)
}

private def postToTranscriptionQueue(blobUri: String, signedUrl: String) = {
val transcriptionJob = TranscriptionJob(UUID.randomUUID().toString, blobUri, signedUrl, DateTime.now().toString, "giant", "Giant",
getOutputBucketUrls(blobUri), "")
amazonSQSClient.sendMessage(transcribeConfig.transcriptionServiceQueueUrl, Json.stringify(Json.toJson(transcriptionJob)))
}

override def triggerExtraction(blob: Blob, params: ExtractionParams): Either[Failure, Unit] = {
blobStorage.getSignedUrl (blob.uri.value).map {
url => postToTranscriptionQueue(blob.uri.value, url)
}
}
import scala.jdk.CollectionConverters._

override def pollForResults(): Either[Failure, Unit] = {
val messages = amazonSQSClient.receiveMessage(
new ReceiveMessageRequest(transcribeConfig.transcriptionServiceOutputQueueUrl).withMaxNumberOfMessages(10)
).getMessages

messages.asScala.toList.foreach { message =>
val transcriptionOutput = Json.parse(message.getBody).as[TranscriptionOutput]
blobStorage.get(transcriptionOutput.outputBucketUrls.txt).map { inputStream =>
val txt = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8)
index.addDocumentTranscription(Uri(transcriptionOutput.originalFilename), txt, None, Languages.getByIso6391Code(transcriptionOutput.languageCode).getOrElse(English))
.recoverWith {
case _ =>
val msg = s"Failed to write transcript result to elasticsearch. Transcript language: ${transcriptionOutput.languageCode}"
logger.error(msg)
// throw the error - will be caught by catchNonFatal
throw new Error(msg)
}

}
}
Right(())
}

}
2 changes: 2 additions & 0 deletions backend/app/extraction/Extractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ trait Extractor {
def cost(mimeType: MimeType, size: Long): Long = size

def extract(blob: Blob, inputStream: InputStream, params: ExtractionParams): Either[Failure, Unit]

def external: Boolean = false
}
39 changes: 24 additions & 15 deletions backend/app/services/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ case class OcrConfig(

case class TranscribeConfig(
whisperModelFilename: String,
transcriptionServiceQueueUrl: String,
transcriptionServiceOutputQueueUrl: String
)

case class WorkerConfig(
Expand All @@ -90,7 +92,8 @@ case class WorkerConfig(
controlInterval: FiniteDuration,
controlCooldown: FiniteDuration,
enabled: Boolean,
workspace: String
workspace: String,
useExternalExtractors: Boolean
)

case class Neo4jQueryLoggingConfig(
Expand Down Expand Up @@ -154,6 +157,10 @@ case class S3Config(
sseAlgorithm: Option[String]
)

case class SQSConfig(
region: String
)

case class BucketConfig(
ingestion: String,
deadLetter: String,
Expand All @@ -173,19 +180,20 @@ case class AWSDiscoveryConfig(
)

case class Config(
underlying: com.typesafe.config.Config,
app: AppConfig,
auth: AuthConfig,
worker: WorkerConfig,
neo4j: Neo4jConfig,
postgres: Option[PostgresConfig],
elasticsearch: ElasticsearchConfig,
ingestion: IngestConfig,
preview: PreviewConfig,
s3: S3Config,
aws: Option[AWSDiscoveryConfig],
ocr: OcrConfig,
transcribe: TranscribeConfig
underlying: com.typesafe.config.Config,
app: AppConfig,
auth: AuthConfig,
worker: WorkerConfig,
neo4j: Neo4jConfig,
postgres: Option[PostgresConfig],
elasticsearch: ElasticsearchConfig,
ingestion: IngestConfig,
preview: PreviewConfig,
s3: S3Config,
aws: Option[AWSDiscoveryConfig],
ocr: OcrConfig,
transcribe: TranscribeConfig,
sqs: SQSConfig
)

object Config {
Expand All @@ -202,7 +210,8 @@ object Config {
raw.as[S3Config]("s3"),
raw.as[Option[AWSDiscoveryConfig]]("aws"),
raw.as[OcrConfig]("ocr"),
raw.as[TranscribeConfig]("transcribe")
raw.as[TranscribeConfig]("transcribe"),
raw.as[SQSConfig]("sqs")
)

private def parseAuth(rawAuthConfig: com.typesafe.config.Config): AuthConfig = {
Expand Down
12 changes: 11 additions & 1 deletion backend/app/services/ObjectStorage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ import com.amazonaws.services.s3.model.{DeleteObjectsRequest, ListObjectsRequest
import java.io.InputStream
import java.nio.file.Path
import model.ObjectMetadata
import org.joda.time.DateTime
import utils.attempt.{Failure, IllegalStateFailure, UnknownFailure}
import utils.aws.{AwsErrors, S3Client}
import scala.jdk.CollectionConverters._

import java.util.Date
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

trait ObjectStorage {
def create(key: String, path: Path, mimeType: Option[String] = None): Either[Failure, Unit]
def get(key: String): Either[Failure, InputStream]
def getSignedUrl(key: String): Either[Failure, String]
def getMetadata(key: String): Either[Failure, ObjectMetadata]
def delete(key: String): Either[Failure, Unit]
def deleteMultiple(key: Set[String]): Either[Failure, Unit]
Expand All @@ -30,6 +33,13 @@ class S3ObjectStorage private(client: S3Client, bucket: String) extends ObjectSt
run(client.aws.getObject(bucket, key).getObjectContent)
}

def getSignedUrl(key: String): Either[Failure, String] = {

val thisTimeTomorrow = new DateTime().plusDays(1)

run(client.aws.generatePresignedUrl(bucket, key,thisTimeTomorrow.toDate).toString)
}

def getMetadata(key: String): Either[Failure, ObjectMetadata] = run {
val stats = client.aws.getObjectMetadata(bucket, key)
ObjectMetadata(stats.getContentLength, stats.getContentType)
Expand Down
7 changes: 4 additions & 3 deletions backend/app/utils/AwsDiscovery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ object AwsDiscovery extends Logging {
// Using the instanceId as the worker name will allow us to break locks on terminated instances in the future
worker = maybeInstanceId.map { instanceId =>
config.worker.copy(
name = Some(instanceId)
)
name = Some(instanceId))
}.getOrElse(config.worker),
transcribe = config.transcribe.copy(
whisperModelFilename = readSSMParameter("transcribe/modelFilename", stack, stage, ssmClient)
whisperModelFilename = readSSMParameter("transcribe/modelFilename", stack, stage, ssmClient),
transcriptionServiceOutputQueueUrl = readSSMParameter("transcribe/transcriptionServiceOutputQueueUrl", stack, stage, ssmClient),
transcriptionServiceQueueUrl = readSSMParameter("transcribe/transcriptionServiceQueueUrl", stack, stage, ssmClient)
),
underlying = config.underlying
.withValue("play.http.secret.key", fromAnyRef(readSSMParameter("pfi/playSecret", stack, stage, ssmClient)))
Expand Down
3 changes: 3 additions & 0 deletions backend/conf/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ worker {
controlCooldown = "5 minutes"
enabled = true
workspace = "/tmp"
useExternalExtractors = true
}

neo4j {
Expand Down Expand Up @@ -205,6 +206,8 @@ ocr {

transcribe {
whisperModelFilename = "ggml-base.bin"
transcriptionServiceQueueUrl = ""
transcriptionServiceOutputQueueUrl = ""
}

# This will overwrite some settings from above
Expand Down
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ lazy val backend = (project in file("backend"))
"com.amazonaws" % "aws-java-sdk-autoscaling" % awsVersion,
"com.amazonaws" % "aws-java-sdk-cloudwatch" % awsVersion,
"com.amazonaws" % "aws-java-sdk-cloudwatchmetrics" % awsVersion,
"com.amazonaws" % "aws-java-sdk-sqs" % awsVersion,
"com.beachape" %% "enumeratum-play" % "1.8.0",
"com.iheart" %% "ficus" % "1.5.2",
"org.jsoup" % "jsoup" % "1.14.2",
Expand Down

0 comments on commit 46377d9

Please sign in to comment.