Skip to content

Commit

Permalink
Entity classification updates (#281)
Browse files Browse the repository at this point in the history
* node classification updates

* update

* remove unused code

* update
  • Loading branch information
prasmussen15 authored Feb 27, 2025
1 parent 1d2417e commit 6f87473
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
11 changes: 8 additions & 3 deletions graphiti_core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ class MissedEntities(BaseModel):


class EntityClassification(BaseModel):
entity_classification: str = Field(
entities: list[str] = Field(
...,
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
description='List of entities',
)
entity_classifications: list[str | None] = Field(
...,
description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
)


Expand Down Expand Up @@ -180,7 +184,8 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Each entity must have exactly one type
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
"""
return [
Message(role='system', content=sys_prompt),
Expand Down
10 changes: 6 additions & 4 deletions graphiti_core/utils/maintenance/node_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
limitations under the License.
"""

import ast
import logging
from time import time

Expand Down Expand Up @@ -163,8 +162,9 @@ async def extract_nodes(
prompt_library.extract_nodes.classify_nodes(node_classification_context),
response_model=EntityClassification,
)
response_string = llm_response.get('entity_classification', '{}')
node_classifications.update(ast.literal_eval(response_string))
entities = llm_response.get('entities', [])
entity_classifications = llm_response.get('entity_classifications', [])
node_classifications.update(dict(zip(entities, entity_classifications)))

end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
Expand All @@ -173,7 +173,9 @@ async def extract_nodes(
for name in extracted_node_names:
entity_type = node_classifications.get(name)
labels = (
['Entity'] if entity_type is None or entity_type == 'None' else ['Entity', entity_type]
['Entity']
if entity_type is None or entity_type == 'None' or entity_type == 'null'
else ['Entity', entity_type]
)

new_node = EntityNode(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.7.4"
version = "0.7.5"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <[email protected]>",
Expand Down

0 comments on commit 6f87473

Please sign in to comment.