Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass Tags using an explicit form field #227

Merged
merged 1 commit into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions clients/dotnet/WebClient/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ private async Task<string> ImportInternalAsync(
using (StringContent documentIdContent = new(uploadRequest.DocumentId))
{
List<IDisposable> disposables = new();
formData.Add(documentIdContent, Constants.WebServiceDocumentIdField);
formData.Add(indexContent, Constants.WebServiceIndexField);
formData.Add(documentIdContent, Constants.WebServiceDocumentIdField);

// Add steps to the form
foreach (string? step in uploadRequest.Steps)
Expand All @@ -336,9 +336,9 @@ private async Task<string> ImportInternalAsync(
// Add tags to the form
foreach (KeyValuePair<string, string?> tag in uploadRequest.Tags.Pairs)
{
var tagContent = new StringContent(tag.Value);
var tagContent = new StringContent($"{tag.Key}{Constants.ReservedEqualsChar}{tag.Value}");
disposables.Add(tagContent);
formData.Add(tagContent, tag.Key);
formData.Add(tagContent, Constants.WebServiceTagsField);
}

// Add files to the form
Expand All @@ -360,7 +360,6 @@ private async Task<string> ImportInternalAsync(
try
{
HttpResponseMessage? response = await this._client.PostAsync("/upload", formData, cancellationToken).ConfigureAwait(false);
formData.Dispose();
response.EnsureSuccessStatusCode();
}
catch (HttpRequestException e) when (e.Data.Contains("StatusCode"))
Expand All @@ -373,6 +372,7 @@ private async Task<string> ImportInternalAsync(
}
finally
{
formData.Dispose();
foreach (var disposable in disposables)
{
disposable.Dispose();
Expand Down
2 changes: 1 addition & 1 deletion nuget-package.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project>
<PropertyGroup>
<!-- Central version prefix - applies to all nuget packages. -->
<Version>0.24.0</Version>
<Version>0.25.0</Version>

<!-- These are set at the project level-->
<IsPackable>false</IsPackable>
Expand Down
3 changes: 3 additions & 0 deletions service/Abstractions/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ public static class Constants
// Form field containing the Document ID
public const string WebServiceDocumentIdField = "documentId";

// Form field containing the list of tags
public const string WebServiceTagsField = "tags";

// Form field containing the list of pipeline steps
public const string WebServiceStepsField = "steps";

Expand Down
6 changes: 6 additions & 0 deletions service/Abstractions/Models/TagCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,11 @@ private static void ValidateKey(string key)
{
throw new KernelMemoryException("A tag name cannot contain the '=' char");
}

// ':' is reserved for backward/forward compatibility
if (key.Contains(':'))
{
throw new KernelMemoryException("A tag name cannot contain the ':' char");
}
}
}
68 changes: 34 additions & 34 deletions service/Core/WebService/HttpDocumentUploadRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ public class HttpDocumentUploadRequest
public static async Task<(HttpDocumentUploadRequest model, bool isValid, string errMsg)> BindHttpRequestAsync(
HttpRequest httpRequest, CancellationToken cancellationToken = default)
{
const string IndexField = Constants.WebServiceIndexField;
const string DocumentIdField = Constants.WebServiceDocumentIdField;
const string StepsField = Constants.WebServiceStepsField;

var result = new HttpDocumentUploadRequest();

// Content format validation
Expand All @@ -50,71 +46,75 @@ public class HttpDocumentUploadRequest
return (result, false, "No file was uploaded");
}

if (form.TryGetValue(IndexField, out StringValues indexes) && indexes.Count > 1)
// Only one index can be defined
if (form.TryGetValue(Constants.WebServiceIndexField, out StringValues indexes) && indexes.Count > 1)
{
return (result, false, $"Invalid index name, '{IndexField}', multiple values provided");
return (result, false, $"Invalid index name, '{Constants.WebServiceIndexField}', multiple values provided");
}

if (form.TryGetValue(DocumentIdField, out StringValues documentIds) && documentIds.Count > 1)
// Only one document ID can be defined
if (form.TryGetValue(Constants.WebServiceDocumentIdField, out StringValues documentIds) && documentIds.Count > 1)
{
return (result, false, $"Invalid document ID, '{DocumentIdField}' must be a single value, not a list");
return (result, false, $"Invalid document ID, '{Constants.WebServiceDocumentIdField}' must be a single value, not a list");
}

// Document Id is optional, e.g. used if the client wants to retry the same upload, otherwise we generate a random/unique one
var documentId = documentIds.FirstOrDefault();
// Document Id is optional, e.g. used if the client wants to retry the same upload, otherwise a random/unique one is generated
string? documentId = documentIds.FirstOrDefault();
if (string.IsNullOrWhiteSpace(documentId))
{
documentId = DateTimeOffset.Now.ToString("yyyyMMdd.HHmmss.", CultureInfo.InvariantCulture) + Guid.NewGuid().ToString("N");
}

// Optional document tags. Tags are passed in as "key:value", where a key can have multiple values. See TagCollection.
if (form.TryGetValue(Constants.WebServiceTagsField, out StringValues tags))
{
foreach (string? tag in tags)
{
if (tag == null) { continue; }

var keyValue = tag.Split(Constants.ReservedEqualsChar, 2);
string key = keyValue[0];
ValidateTagName(key);
string? value = keyValue.Length == 1 ? null : keyValue[1];
result.Tags.Add(key, value);
}
}

// Optional pipeline steps. The user can pass a custom list or leave it to the system to use the default.
if (form.TryGetValue(StepsField, out StringValues steps))
if (form.TryGetValue(Constants.WebServiceStepsField, out StringValues steps))
{
foreach (string? step in steps)
{
if (string.IsNullOrWhiteSpace(step)) { continue; }

// Allow step names to be separated by space, comma, semicolon
var list = step.Replace(' ', ';').Replace(',', ';').Split(';');
result.Steps.AddRange(from s in list where !string.IsNullOrWhiteSpace(s) select s.Trim());
}
}

result.DocumentId = documentId;
result.Index = indexes[0]!;
result.DocumentId = documentId;
result.Files = form.Files;

// Store any extra field as a tag
foreach (string key in form.Keys)
{
if (key == DocumentIdField
|| key == IndexField
|| key == StepsField
|| !form.TryGetValue(key, out StringValues values)) { continue; }

ValidateTagName(key);
foreach (string? x in values)
{
result.Tags.Add(key, x);
}
}

return (result, true, string.Empty);
}

private static void ValidateTagName(string key)
private static void ValidateTagName(string tagName)
{
if (key.Contains('=', StringComparison.Ordinal))
if (tagName.StartsWith(Constants.ReservedTagsPrefix, StringComparison.Ordinal))
{
throw new KernelMemoryException("A tag name cannot contain the '=' char");
throw new KernelMemoryException(
$"The tag name prefix '{Constants.ReservedTagsPrefix}' is reserved for internal use.");
}

if (key is
Constants.ReservedDocumentIdTag
if (tagName is Constants.ReservedDocumentIdTag
or Constants.ReservedFileIdTag
or Constants.ReservedFilePartitionTag
or Constants.ReservedFileTypeTag)
or Constants.ReservedFileTypeTag
or Constants.ReservedSyntheticTypeTag)
{
throw new KernelMemoryException($"The tag name '{key}' is reserved for internal use.");
throw new KernelMemoryException($"The tag name '{tagName}' is reserved for internal use.");
}
}
}
7 changes: 7 additions & 0 deletions service/Service/Service.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.24.231228.5" />
</ItemGroup>

<!-- <ItemGroup>-->
<!-- <ProjectReference Include="..\Core\Core.csproj"/>-->
<!-- <ProjectReference Include="..\..\extensions\LlamaSharp\LlamaSharp.csproj"/>-->
<!-- <ProjectReference Include="..\..\extensions\Postgres\Postgres\Postgres.csproj"/>-->
<!-- <ProjectReference Include="..\..\extensions\RabbitMQ\RabbitMQ.csproj"/>-->
<!-- </ItemGroup>-->

<!-- Code Analysis -->
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4">
Expand Down
34 changes: 34 additions & 0 deletions service/tests/ServiceFunctionalTests/DocumentUploadTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,38 @@ await this._memory.ImportDocumentAsync(
this.Log("Deleting memories extracted from the document");
await this._memory.DeleteDocumentAsync(Id);
}

/// <summary>
/// Imports a document with several tags (including two values under the same
/// "type" key) and checks that questions filtered by each of those tags are
/// answered from the document, while a filter on a tag that was never set
/// produces a "NOT FOUND" answer.
/// </summary>
[Fact]
[Trait("Category", "ServiceFunctionalTest")]
public async Task ItSupportTags()
{
    // Arrange
    const string Id = "ItSupportTags-file1-NASA-news.pdf";
    await this._memory.ImportDocumentAsync(
        "file1-NASA-news.pdf",
        documentId: Id,
        tags: new TagCollection
        {
            // The same key is used twice on purpose: a tag key can carry multiple values
            { "type", "news" },
            { "type", "test" },
            { "ext", "pdf" }
        },
        steps: Constants.PipelineWithoutSummary);

    // Act
    var answer1 = await this._memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("type", "news"));
    this.Log(answer1.Result);
    var answer2 = await this._memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("type", "test"));
    this.Log(answer2.Result);
    var answer3 = await this._memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("ext", "pdf"));
    this.Log(answer3.Result);
    // This tag was never attached to the document, so no memory should match the filter
    var answer4 = await this._memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("foo", "bar"));
    this.Log(answer4.Result);

    // Assert
    Assert.Contains("spacecraft", answer1.Result, StringComparison.OrdinalIgnoreCase);
    Assert.Contains("spacecraft", answer2.Result, StringComparison.OrdinalIgnoreCase);
    Assert.Contains("spacecraft", answer3.Result, StringComparison.OrdinalIgnoreCase);
    Assert.Contains("NOT FOUND", answer4.Result, StringComparison.OrdinalIgnoreCase);

    // Cleanup: remove the imported document so its memories do not leak into other functional tests
    this.Log("Deleting memories extracted from the document");
    await this._memory.DeleteDocumentAsync(Id);
}
}