Skip to content

Commit

Permalink
image classification #23
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Dec 19, 2020
1 parent 43c06d5 commit 0316fde
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 43 deletions.
21 changes: 15 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SciSharp STACK Examples
This repo contains many practical examples written in SciSharp's machine learning libraries. If you still don't know how to use .NET for deep learning, getting started from here is your best choice.
This repo contains many practical examples written in SciSharp's machine learning libraries. If you still don't know how to use .NET for deep learning, getting started from these examples is your best choice.

[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)

Expand Down Expand Up @@ -30,6 +30,8 @@ dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN (Eager)"

Example runner will download all the required files like training data and model pb files.

#### Basic Model

* Hello World [C#](src/TensorFlowNET.Examples/HelloWorld.cs)
* Basic Operations [C#](src/TensorFlowNET.Examples/BasicOperations.cs)
* Linear Regression in Graph mode [C#](src/TensorFlowNET.Examples/BasicModels/LinearRegression.cs)
Expand All @@ -38,22 +40,29 @@ Example runner will download all the required files like training data and model
* Logistic Regression in Eager mode [C#](src/TensorFlowNET.Examples/BasicModels/LogisticRegressionEager.cs)
* Nearest Neighbor [C#](src/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs)
* Naive Bayes Classification [C#](src/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs)
* Full Connected Neural Network in Eager mode [C#](src/TensorFlowNET.Examples/\NeuralNetworks/FullyConnectedEager.cs)
* K-means Clustering [C#](src/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs)

#### Neural Network

* Full Connected Neural Network in Eager mode [C#](src/TensorFlowNET.Examples/\NeuralNetworks/FullyConnectedEager.cs)
* NN XOR [C#](src/TensorFlowNET.Examples/NeuralNetworks/NeuralNetXor.cs)
* Object Detection in MobileNet [C#](src/TensorFlowNET.Examples/ObjectDetection/DetectInMobilenet.cs)
* Binary Text Classification [C#](src/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs)
* CNN Text Classification [C#](src/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs)
* MNIST FNN in Keras Functional API [C#](src/TensorFlowNET.Examples/ImageProcessing/MnistFnnKerasFunctional.cs)
* MNIST CNN in Graph mode [C#](src/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs)
* MNIST RNN [C#](src/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs)
* MNIST LSTM [C#](src/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionLSTM.cs)
* Named Entity Recognition [C#](src/TensorFlowNET.Examples/TextProcessing/NER)
* Image Classification in Keras Sequential API [C#](src/TensorFlowNET.Examples/ImageProcessing/ImageClassificationKeras.cs)
* Toy ResNet in Keras Functional API [C#](src/TensorFlowNET.Examples/ImageProcessing/ToyResNet.cs)
* Transfer Learning for Image Classification in InceptionV3 [C#](src/TensorFlowNET.Examples/ImageProcessing/TransferLearningWithInceptionV3.cs)
* CNN In Your Own Dataset [C#](src/TensorFlowNET.Examples/ImageProcessing/CnnInYourOwnData.cs)

#### Natural Language Processing
* Binary Text Classification [C#](src/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs)
* CNN Text Classification [C#](src/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs)
* Named Entity Recognition [C#](src/TensorFlowNET.Examples/TextProcessing/NER)


### Welcome to PR your example to us.
#### Welcome to PR your example to us.
Your contribution will make .NET community better than ever.
<br>
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
using Tensorflow.Keras;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;
using Tensorflow.Keras.Utils;
using System.IO;
using Tensorflow.Keras.Engine;

namespace TensorFlowNET.Examples
{
Expand All @@ -12,67 +16,100 @@ namespace TensorFlowNET.Examples
/// </summary>
public class ImageClassificationKeras : SciSharpExample, IExample
{
int batch_size = 32;
int epochs = 10;
TensorShape img_dim = (180, 180);
IDatasetV2 train_ds, val_ds;
Model model;

public ExampleConfig InitConfig()
=> Config = new ExampleConfig
{
Name = "Image Classification (Keras)",
Enabled = false,
Enabled = true,
Priority = 18
};

public bool Run()
{
tf.enable_eager_execution();

PrepareData();
BuildModel();
Train();

return true;
}

public override void BuildModel()
{
int num_classes = 5;
// var normalization_layer = tf.keras.layers.Rescaling(1.0f / 255);
var layers = keras.layers;
model = keras.Sequential(new List<ILayer>
{
layers.Rescaling(1.0f / 255, input_shape: (img_dim.dims[0], img_dim.dims[1], 3)),
layers.Conv2D(16, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
/*layers.Conv2D(32, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),*/
layers.Flatten(),
layers.Dense(128, activation: keras.activations.Relu),
layers.Dense(num_classes)
});

model.compile(optimizer: keras.optimizers.Adam(),
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
metrics: new[] { "accuracy" });

model.summary();
}

public override void Train()
{
model.fit(train_ds, validation_data: val_ds, epochs: epochs);
}

public override void PrepareData()
{
int batch_size = 32;
TensorShape img_dim = (180, 180);
string fileName = "flower_photos.tgz";
string url = $"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz";
string data_dir = Path.GetTempPath();
Web.Download(url, data_dir, fileName);
Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);
data_dir = Path.Combine(data_dir, "flower_photos");

var data_dir = @"C:/Users/haipi/.keras/datasets/flower_photos";
var train_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
// convert to tensor
train_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
validation_split: 0.2f,
subset: "training",
seed: 123,
image_size: img_dim,
batch_size: batch_size);

var val_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
val_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
validation_split: 0.2f,
subset: "validation",
seed: 123,
image_size: img_dim,
batch_size: batch_size);

train_ds = train_ds.cache().shuffle(100).prefetch(buffer_size: -1);
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size: -1);
val_ds = val_ds.cache().prefetch(buffer_size: -1);

foreach (var (img, label) in train_ds)
{
print("batch images: " + img.TensorShape);
print("labels: " + label);
print($"images: {img.TensorShape}");
var nd = label.numpy();
print($"labels: {nd}");
var data = nd.Data<int>();
if (data.Max() > 4 || data.Min() < 0)
{
// exception
}
}

int num_classes = 5;
// var normalization_layer = tf.keras.layers.Rescaling(1.0f / 255);
var layers = keras.layers;
var model = keras.Sequential(new List<ILayer>
{
layers.Rescaling(1.0f / 255, input_shape: (img_dim.dims[0], img_dim.dims[1], 3)),
layers.Conv2D(16, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
/*layers.Conv2D(32, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),*/
layers.Flatten(),
layers.Dense(128, activation: keras.activations.Relu),
layers.Dense(num_classes)
});

model.compile("adam", keras.losses.SparseCategoricalCrossentropy(from_logits: true));
}
}
}
7 changes: 3 additions & 4 deletions src/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>

Expand Down Expand Up @@ -38,13 +38,12 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Colorful.Console" Version="1.2.11" />
<PackageReference Include="Colorful.Console" Version="1.2.15" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="OpenCvSharp4.runtime.win" Version="4.4.0.20200915" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
<PackageReference Include="SharpCV" Version="0.6.0" />
<PackageReference Include="System.Drawing.Common" Version="4.7.0" />
<PackageReference Include="TensorFlow.Keras" Version="0.2.1" />
<PackageReference Include="System.Drawing.Common" Version="5.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 2 additions & 0 deletions src/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public ExampleConfig InitConfig()

public bool Run()
{
tf.compat.v1.disable_eager_execution();

PrepareData();

var graph = tf.Graph().as_default();
Expand Down
17 changes: 17 additions & 0 deletions src/tensorflow2.x-python-tutorial/.vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/image_classification.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"justMyCode": false
}
]
}
7 changes: 5 additions & 2 deletions src/tensorflow2.x-python-tutorial/MapDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

# Import TensorFlow v2.
import tensorflow as tf
from tensorflow.python.data.experimental.ops import cardinality

tf.autograph.set_verbosity(10, alsologtostdout=True)
tf.config.run_functions_eagerly(True)

def add(x):
print("debug test test")
return x + 1

dataset = tf.data.Dataset.range(1, 3) # ==> [ 1, 2, 3, 4, 5 ]
dataset = tf.data.Dataset.range(10) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(add)
card = cardinality.cardinality(dataset)
for item in dataset:
print(item)
print(item)
19 changes: 16 additions & 3 deletions src/tensorflow2.x-python-tutorial/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from tensorflow.keras.models import Sequential
from tensorflow.python.ops import bitwise_ops

"""
y_true = np.array([[0., 1.], [0., 0.]],dtype='float64')
y_pred = np.array([[1., 1.], [1., 0.]],dtype='float64')
mse = tf.keras.losses.MeanSquaredError()
print(mse(y_true, y_pred).numpy())
"""

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
Expand All @@ -24,14 +31,21 @@
img_height = 180
img_width = 180

tf.config.run_functions_eagerly(True)

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
shuffle=False,
image_size=(img_height, img_width),
batch_size=batch_size)

for img, label in train_ds:
print("batch images: ", img.shape)
print("labels: ", label)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
Expand All @@ -47,7 +61,7 @@

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
Expand Down Expand Up @@ -82,8 +96,7 @@
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
epochs=epochs)

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
Expand Down
12 changes: 12 additions & 0 deletions src/tensorflow2.x-python-tutorial/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np
import tensorflow as tf
import tensorflow_text as text

docs = tf.constant([u'Everything not saved will be lost.'])
tokenizer = text.WhitespaceTokenizer()
tokens = tokenizer.tokenize(docs)
f1 = text.wordshape(tokens, text.WordShape.HAS_TITLE_CASE)
bigrams = text.ngrams(tokens, 2, reduction_type=text.Reduction.STRING_JOIN)

print(docs)
a = input()

0 comments on commit 0316fde

Please sign in to comment.