-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcolorFeatures_RGB.py
63 lines (50 loc) · 1.88 KB
/
colorFeatures_RGB.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import numpy as np
import pandas as pd
from skimage import io
# Function to calculate average color (RGB) values
def calculate_average_color(image):
    """Return the mean red, green and blue value of an image.

    Args:
        image: Array-like image. A 3-D array is treated as
            (rows, cols, channels):
            - With three or more channels, only the first three are averaged,
              so an alpha channel (e.g. from an RGBA PNG) is ignored.
            - With one or two channels (grayscale, or grayscale + alpha),
              channel 0 is treated as the intensity.
            A 2-D array is treated as a grayscale image. For grayscale input
            the single mean intensity is repeated for R, G and B.

    Returns:
        numpy.ndarray of shape (3,) holding the per-channel means, on the
        same scale as the input pixel values.

    Raises:
        ValueError: if the image is neither 2-D nor 3-D, or has no channels.
    """
    pixels = np.asarray(image)
    if pixels.ndim == 3 and 0 < pixels.shape[2] < 3:
        # Grayscale (+ optional alpha) stored with a channel axis: keep intensity only.
        pixels = pixels[:, :, 0]
    if pixels.ndim == 2:
        # Grayscale: one intensity per pixel, so R = G = B.
        return np.repeat(np.mean(pixels), 3)
    if pixels.ndim != 3 or pixels.shape[2] == 0:
        raise ValueError(
            "expected a 2-D grayscale or (H, W, C) colour image, got shape {}".format(
                pixels.shape
            )
        )
    # Drop channels beyond RGB (typically alpha) so every image yields exactly
    # three features, matching the 3 feature columns built from them.
    return np.mean(pixels[:, :, :3], axis=(0, 1))
# Function to load images and extract color features
def load_images_and_extract_color_features(folder, extensions=(".jpg", ".png")):
    """Compute one colour-feature row per image found under ``folder/<class>/``.

    Each immediate sub-directory of *folder* is treated as a class, and its
    name becomes the label of every image inside it. Non-directory entries
    directly in *folder* are ignored.

    Args:
        folder: Path to a dataset split directory (e.g. ".../train").
        extensions: File-name suffix, or tuple of suffixes, to accept.
            Matching is case-insensitive, so "IMG_001.JPG" is picked up as
            well as "img_001.jpg". The default matches the original
            ".jpg"/".png" selection.

    Returns:
        (features, labels): two parallel lists. ``features[i]`` is what
        ``calculate_average_color`` returns for the i-th image, and
        ``labels[i]`` is the name of the sub-directory it came from.
        Directories and files are visited in sorted order, so the row order
        is reproducible across runs and platforms.

    Raises:
        FileNotFoundError: propagated from ``os.listdir`` if *folder* does
            not exist.
    """
    if isinstance(extensions, str):
        # Guard against iterating a single suffix character by character.
        extensions = (extensions,)
    wanted = tuple(ext.lower() for ext in extensions)

    features = []  # Features
    labels = []  # Labels
    for class_folder in sorted(os.listdir(folder)):
        class_path = os.path.join(folder, class_folder)
        if not os.path.isdir(class_path):
            continue  # Skip if not a directory
        for filename in sorted(os.listdir(class_path)):
            file_path = os.path.join(class_path, filename)
            # Case-insensitive suffix match; also skip e.g. a sub-directory
            # whose name happens to end in ".jpg".
            if not filename.lower().endswith(wanted) or not os.path.isfile(file_path):
                continue
            image = io.imread(file_path)
            features.append(calculate_average_color(image))
            labels.append(class_folder)  # Use folder name as label
    return features, labels
# Root folder holding the "train" and "test" image splits; edit to point at your RGB dataset.
rgb_folder = "data/RGB data"

# Column names for the three per-channel means, shared by both output CSVs.
_RGB_COLUMNS = ["Avg_Red", "Avg_Green", "Avg_Blue"]


def _build_labeled_frame(feature_rows, label_values):
    """Return a DataFrame of per-image channel means followed by a Label column."""
    frame = pd.DataFrame(feature_rows, columns=_RGB_COLUMNS)
    frame["Label"] = label_values
    return frame


# Extract features for the training split first, then the test split.
train_features, train_labels = load_images_and_extract_color_features(
    os.path.join(rgb_folder, "train")
)
test_features, test_labels = load_images_and_extract_color_features(
    os.path.join(rgb_folder, "test")
)

train_df = _build_labeled_frame(train_features, train_labels)
test_df = _build_labeled_frame(test_features, test_labels)

# Persist both splits (without the integer index), train first.
train_csv, test_csv = "train_RGB.csv", "test_RGB.csv"
for frame_to_save, csv_path in ((train_df, train_csv), (test_df, test_csv)):
    frame_to_save.to_csv(csv_path, index=False)

print("Train CSV saved to:", train_csv)
print("Test CSV saved to:", test_csv)