
Machine Learning in JavaScript: TensorFlow.js Basics


Machine learning has traditionally been the domain of Python and specialized hardware. With TensorFlow.js, you can now run machine learning models directly in the browser using JavaScript, which opens up exciting possibilities for the web. In this guide, we'll take TensorFlow.js from the basics through advanced techniques, covering what you need to build intelligent web applications.

For a comprehensive video introduction to TensorFlow concepts that complements this guide, check out this complete TensorFlow 2.0 course:

TensorFlow 2.0 Complete Course

Note: This 7-hour course covers TensorFlow fundamentals in Python, while this article focuses on TensorFlow.js for JavaScript/browser-based machine learning.

Why TensorFlow.js?

Browser-Based ML Benefits:

  • No server required: Run models entirely in the browser
  • Privacy-focused: Data stays on user's device
  • Real-time processing: Instant results without network latency
  • Offline capable: Works without internet connection
  • JavaScript ecosystem: Leverages existing web development skills

Use Cases:

  • Image recognition: Classify images, detect objects (see the sketch after this list)
  • Natural language processing: Sentiment analysis, text generation
  • Audio processing: Speech recognition, music generation
  • Gesture recognition: Hand tracking, pose estimation
  • Recommendation systems: Personalized content suggestions
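
Many of these use cases are covered by ready-made model packages, so you can get a first result before training anything yourself. A minimal sketch of image classification (assuming the @tensorflow-models/mobilenet package is installed and the page contains an <img> element with id "photo"):

import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';

// Classify an image that is already in the DOM
const img = document.getElementById('photo');
const model = await mobilenet.load();          // downloads the pre-trained weights
const predictions = await model.classify(img); // [{className, probability}, ...]
console.log(predictions);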

Getting Started

1. Installation

# Using npm
npm install @tensorflow/tfjs

# Using yarn
yarn add @tensorflow/tfjs

# For specific backends
npm install @tensorflow/tfjs-node  # Node.js backend
npm install @tensorflow/tfjs-backend-wasm  # WebAssembly backend
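
Installing a separate backend package is not enough on its own; importing it registers the backend, and you then switch to it explicitly. A minimal sketch for the WebAssembly backend (assuming your bundler serves the package's .wasm files, or that you point setWasmPaths at a CDN copy):

import * as tf from '@tensorflow/tfjs';
import {setWasmPaths} from '@tensorflow/tfjs-backend-wasm'; // importing registers the 'wasm' backend

// Optional: tell the backend where its .wasm binaries are hosted
setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/');

await tf.setBackend('wasm');
await tf.ready();
console.log('Active backend:', tf.getBackend());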

2. Basic Setup

<!DOCTYPE html>
<html>
<head>
  <title>TensorFlow.js App</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
</head>
<body>
  <h1>TensorFlow.js Demo</h1>
  <div id="output"></div>
  <script src="app.js"></script>
</body>
</html>

3. First TensorFlow.js Program

// app.js
async function run() {
  // Create a simple tensor
  const tensor = tf.tensor([1, 2, 3, 4]);
  console.log('Tensor:', tensor);

  // Perform operations
  const squared = tensor.square();
  console.log('Squared:', squared);

  // Clean up memory
  tensor.dispose();
  squared.dispose();

  document.getElementById('output').innerHTML = 'TensorFlow.js is working!';
}

run();

Core Concepts

1. Tensors: The Building Blocks

Tensors are the fundamental data structure in TensorFlow.js:

// Different ways to create tensors
const scalar = tf.scalar(3.14);                    // 0D tensor
const vector = tf.tensor([1, 2, 3]);               // 1D tensor
const matrix = tf.tensor([[1, 2], [3, 4]]);        // 2D tensor
const cube = tf.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); // 3D tensor

// Create tensors with specific shapes
const zeros = tf.zeros([2, 3]);    // 2x3 matrix of zeros
const ones = tf.ones([3, 3]);      // 3x3 matrix of ones
const random = tf.randomNormal([2, 2]); // 2x2 matrix with random values

// Tensor properties
console.log('Shape:', vector.shape);       // [3]
console.log('DataType:', vector.dtype);    // 'float32'
console.log('Rank:', vector.rank);         // 1

2. Operations and Transformations

const a = tf.tensor([1, 2, 3, 4]);
const b = tf.tensor([5, 6, 7, 8]);

// Element-wise operations
const sum = a.add(b);           // [6, 8, 10, 12]
const product = a.mul(b);       // [5, 12, 21, 32]
const square = a.square();      // [1, 4, 9, 16]

// Matrix operations
const matrixA = tf.tensor([[1, 2], [3, 4]]);
const matrixB = tf.tensor([[5, 6], [7, 8]]);
const matrixProduct = matrixA.matMul(matrixB);

// Reduction operations
const sumAll = a.sum();         // Sum of all elements
const mean = a.mean();          // Mean of all elements
const max = a.max();            // Maximum value
const argMax = a.argMax();      // Index of maximum value

// Reshaping
const reshaped = a.reshape([2, 2]); // Reshape to 2x2 matrix

3. Memory Management

// Always dispose of tensors to prevent memory leaks
function processData() {
  const data = tf.tensor([1, 2, 3, 4, 5, 6]);
  const processed = data.add(10).square();

  // Use the result
  console.log(processed.dataSync());

  // Clean up
  data.dispose();
  processed.dispose();
}

// Or use tidy for automatic cleanup
const result = tf.tidy(() => {
  const data = tf.tensor([1, 2, 3, 4, 5, 6]);
  return data.add(10).square();
});
// Intermediate tensors created inside tidy are disposed automatically;
// the returned tensor is kept so you can keep using it

Building Neural Networks

1. Sequential API (Simple Networks)

// Create a simple sequential model
const model = tf.sequential();

// Add layers
model.add(tf.layers.dense({inputShape: [4], units: 10, activation: 'relu'}));
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

// Compile the model
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// Summary
model.summary();

2. Layers API

// Dense (fully connected) layer
const denseLayer = tf.layers.dense({
  units: 64,
  activation: 'relu',
  inputShape: [784]  // For MNIST images (28x28 = 784)
});

// Convolutional layer for images
const convLayer = tf.layers.conv2d({
  filters: 32,
  kernelSize: [3, 3],
  activation: 'relu',
  inputShape: [28, 28, 1]  // Height, width, channels
});

// Recurrent layer for sequences
const lstmLayer = tf.layers.lstm({
  units: 128,
  returnSequences: false,
  inputShape: [100, 50]  // Sequence length, feature dimension
});

// Dropout for regularization
const dropoutLayer = tf.layers.dropout({
  rate: 0.2
});

// Batch normalization
const batchNormLayer = tf.layers.batchNormalization();

3. Custom Layers

class CustomLayer extends tf.layers.Layer {
  constructor(config) {
    super(config);
    this.units = config.units;  // number of output units for this layer
  }

  build(inputShape) {
    // Create layer weights (kernel and bias)
    this.kernel = this.addWeight('kernel', [inputShape[1], this.units], 'float32');
    this.bias = this.addWeight('bias', [this.units], 'float32');
    this.built = true;
  }

  call(inputs) {
    // Forward pass logic; inputs may arrive as a single tensor or an array
    const input = Array.isArray(inputs) ? inputs[0] : inputs;
    return tf.add(tf.matMul(input, this.kernel.read()), this.bias.read());
  }

  computeOutputShape(inputShape) {
    return [inputShape[0], this.units];
  }

  getConfig() {
    return Object.assign(super.getConfig(), {units: this.units});
  }

  static get className() {
    return 'CustomLayer';
  }
}

Training Models

1. Data Preparation

// Convert data to tensors
const xs = tf.tensor2d([[1, 2], [3, 4], [5, 6]], [3, 2]); // Features
const ys = tf.tensor2d([[0, 1], [1, 0], [0, 1]], [3, 2]); // Labels (one-hot encoded)

// Normalize data
const normalizedXs = xs.div(tf.scalar(10));

// Split into train/test sets
const splitIndex = Math.floor(xs.shape[0] * 0.8);
const trainXs = xs.slice([0, 0], [splitIndex, -1]);
const testXs = xs.slice([splitIndex, 0], [-1, -1]);

2. Training Loop

async function trainModel(model, trainData, epochs = 100) {
  const history = [];

  // model.fit handles the epoch loop, batching, and the validation split
  await model.fit(trainData.xs, trainData.ys, {
    epochs,
    batchSize: 32,
    validationSplit: 0.2,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        history.push(logs);
        console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, accuracy = ${logs.acc.toFixed(4)}`);
      }
    }
  });

  return history;
}

3. Custom Training Loops

async function customTraining(model, trainData, optimizer) {
  const lossFn = tf.losses.meanSquaredError;

  for (let epoch = 0; epoch < 100; epoch++) {
    // optimizer.minimize runs the forward pass, computes gradients with respect
    // to the model's trainable weights, and applies the update in one step
    const loss = optimizer.minimize(
      () => lossFn(trainData.ys, model.predict(trainData.xs)),
      true  // returnCost: return the loss tensor so we can log it
    );

    console.log(`Epoch ${epoch + 1}, Loss: ${loss.dataSync()[0]}`);
    loss.dispose();

    // Yield to the browser between epochs so the UI stays responsive
    await tf.nextFrame();
  }
}

Pre-trained Models

1. Loading Models

// Load MobileNet for image classification (TF Hub hosts graph models,
// so use loadGraphModel with the fromTFHub option)
const mobilenet = await tf.loadGraphModel(
  'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/classification/5/default/1',
  {fromTFHub: true}
);

// Load an SSD MobileNet graph model for object detection
const cocoSsd = await tf.loadGraphModel(
  'https://tfhub.dev/tensorflow/tfjs-model/ssd_mobilenet_v2/1/default/1',
  {fromTFHub: true}
);

2. Using Pre-trained Models

// Image classification with MobileNet
async function classifyImage(imageElement) {
  // Preprocess image
  const tfImage = tf.browser.fromPixels(imageElement);
  const resized = tf.image.resizeBilinear(tfImage, [224, 224]);
  const normalized = resized.div(tf.scalar(255));
  const batched = normalized.expandDims(0);

  // Make prediction
  const predictions = mobilenet.predict(batched);
  const topPrediction = predictions.argMax(-1).dataSync()[0];

  // Clean up
  tfImage.dispose();
  resized.dispose();
  normalized.dispose();
  batched.dispose();
  predictions.dispose();

  return topPrediction;
}

3. Model Conversion

Convert models from Python TensorFlow to TensorFlow.js:

# Install TensorFlow.js converter
pip install tensorflowjs

# Convert SavedModel to TensorFlow.js format
tensorflowjs_converter \
  --input_format=tf_saved_model \
  --output_format=tfjs_graph_model \
  --saved_model_tags=serve \
  input_model/ \
  output_model/
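
After conversion, output_model/ contains a model.json file plus binary weight shards. Serve that directory over HTTP and load it in the browser; a minimal sketch (the path and the 224x224x3 input shape are assumptions that depend on your original model):

// A model converted with --output_format=tfjs_graph_model is a graph model
const model = await tf.loadGraphModel('output_model/model.json');

// Run inference; the input shape must match what the SavedModel expects
const output = model.predict(tf.zeros([1, 224, 224, 3]));
output.print();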

Real-World Applications

1. Image Classification App

class ImageClassifier {
  constructor() {
    this.model = null;
    this.labels = ['Cat', 'Dog', 'Bird', 'Fish'];
  }

  async loadModel() {
    // Load a pre-trained model or create custom one
    this.model = await tf.loadLayersModel('/models/image-classifier/model.json');
  }

  async classifyImage(imageData) {
    const tensor = tf.browser.fromPixels(imageData)
      .resizeNearestNeighbor([224, 224])
      .toFloat()
      .div(tf.scalar(255))
      .expandDims();

    const predictions = await this.model.predict(tensor);
    const predictedClass = predictions.argMax(-1).dataSync()[0];

    tensor.dispose();
    predictions.dispose();

    return this.labels[predictedClass];
  }
}

// Usage
const classifier = new ImageClassifier();
await classifier.loadModel();

// In your UI
document.getElementById('imageInput').addEventListener('change', async (e) => {
  const file = e.target.files[0];
  const img = new Image();
  img.onload = async () => {
    const result = await classifier.classifyImage(img);
    document.getElementById('result').textContent = `Prediction: ${result}`;
  };
  img.src = URL.createObjectURL(file);
});

2. Real-time Object Detection

class ObjectDetector {
  constructor() {
    this.model = null;
  }

  async loadModel() {
    this.model = await tf.loadGraphModel('/models/coco-ssd/model.json');
  }

  async detectObjects(videoElement) {
    const tfImage = tf.browser.fromPixels(videoElement);
    const resized = tf.image.resizeBilinear(tfImage, [300, 300]);
    const normalized = resized.div(tf.scalar(255));
    const batched = normalized.expandDims(0);

    const predictions = await this.model.executeAsync(batched);

    // Process predictions
    const boxes = predictions[0].dataSync();
    const scores = predictions[1].dataSync();
    const classes = predictions[2].dataSync();

    // Filter high-confidence detections
    const threshold = 0.5;
    const detections = [];
    for (let i = 0; i < scores.length; i++) {
      if (scores[i] > threshold) {
        detections.push({
          box: [boxes[i*4], boxes[i*4+1], boxes[i*4+2], boxes[i*4+3]],
          score: scores[i],
          class: classes[i]
        });
      }
    }

    // Clean up
    tfImage.dispose();
    resized.dispose();
    normalized.dispose();
    batched.dispose();
    predictions.forEach(t => t.dispose());

    return detections;
  }
}

3. Natural Language Processing

class SentimentAnalyzer {
  constructor() {
    this.model = null;
    this.tokenizer = null;
  }

  async loadModel() {
    // Load Universal Sentence Encoder
    this.model = await tf.loadLayersModel('/models/use/model.json');

    // Load tokenizer (simplified example)
    this.tokenizer = {
      encode: (text) => {
        // Tokenize text into word indices
        return text.toLowerCase().split(' ').map(word => this.wordToIndex(word));
      }
    };
  }

  // Placeholder vocabulary lookup; a real implementation maps each word to the
  // index it had in the model's training vocabulary
  wordToIndex(word) {
    return word.length % 1000; // stand-in value for illustration only
  }

  async analyzeSentiment(text) {
    const tokens = this.tokenizer.encode(text);
    const tensor = tf.tensor([tokens]);

    const embeddings = await this.model.predict(tensor);
    const sentiment = await this.classifySentiment(embeddings);

    tensor.dispose();
    embeddings.dispose();

    return sentiment;
  }

  async classifySentiment(embeddings) {
    // Simple classification logic (replace with actual model)
    const mean = embeddings.mean().dataSync()[0];
    return mean > 0.5 ? 'positive' : 'negative';
  }
}

Performance Optimization

1. Backend Selection

// Choose the best available backend: WebGL (GPU) first, then WASM, then CPU
async function setupBackend() {
  const preferred = ['webgl', 'wasm', 'cpu'];

  for (const backend of preferred) {
    try {
      // setBackend resolves to true when the backend is registered and usable
      if (await tf.setBackend(backend)) {
        console.log(`Using ${backend} backend`);
        break;
      }
    } catch (e) {
      // Backend not registered or not supported in this environment; try the next one
    }
  }

  await tf.ready();
}

2. Model Optimization

// Quantize the model for smaller size and faster downloads. Quantization is
// applied at conversion time (e.g. with the tensorflowjs_converter
// quantization flags such as --quantize_uint8), not at load time; the
// already-quantized model is then loaded normally and its weights are
// dequantized on the fly.
async function loadQuantizedModel(modelPath) {
  return tf.loadLayersModel(modelPath);
}

// Use model.predict() with tf.tidy for memory efficiency
function efficientPrediction(model, input) {
  return tf.tidy(() => {
    const prediction = model.predict(input);
    return prediction.squeeze(); // Remove unnecessary dimensions
  });
}

3. Memory Management Best Practices

class MemoryManager {
  static tensors = new Set();

  static track(tensor) {
    this.tensors.add(tensor);
    return tensor;
  }

  static dispose(tensor) {
    if (tensor && !tensor.isDisposed) {
      tensor.dispose();
      this.tensors.delete(tensor);
    }
  }

  static disposeAll() {
    this.tensors.forEach(tensor => {
      if (!tensor.isDisposed) {
        tensor.dispose();
      }
    });
    this.tensors.clear();
  }
}

// Usage
const result = MemoryManager.track(model.predict(input));
// ... use result ...
MemoryManager.dispose(result);

Deployment and Production

1. Model Serving

// Serve models from CDN
const MODEL_CONFIG = {
  mobilenet: {
    url: 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json',
    size: '16MB'
  },
  coco: {
    url: 'https://storage.googleapis.com/tfjs-models/tfjs/ssd_mobilenet_v1/model.json',
    size: '27MB'
  }
};

class ModelLoader {
  static cache = new Map();

  static async load(modelName) {
    if (this.cache.has(modelName)) {
      return this.cache.get(modelName);
    }

    const config = MODEL_CONFIG[modelName];
    if (!config) {
      throw new Error(`Model ${modelName} not found`);
    }

    console.log(`Loading model: ${modelName} (${config.size})`);
    const model = await tf.loadLayersModel(config.url);

    this.cache.set(modelName, model);
    return model;
  }
}

2. Progressive Loading

class ProgressiveModelLoader {
  constructor(modelUrls) {
    this.modelUrls = modelUrls;
    this.loadedModels = new Map();
  }

  async loadEssential() {
    // Load critical models first
    const essential = ['mobilenet'];
    for (const modelName of essential) {
      this.loadedModels.set(modelName, await tf.loadLayersModel(this.modelUrls[modelName]));
    }
  }

  async loadOptional() {
    // Load additional models in background
    const optional = ['coco-ssd', 'posenet'];
    for (const modelName of optional) {
      try {
        const model = await tf.loadLayersModel(this.modelUrls[modelName]);
        this.loadedModels.set(modelName, model);
      } catch (error) {
        console.warn(`Failed to load optional model: ${modelName}`, error);
      }
    }
  }

  getModel(name) {
    return this.loadedModels.get(name);
  }
}

3. Error Handling and Fallbacks

class RobustModelRunner {
  constructor() {
    this.models = new Map();
    this.fallbacks = new Map();
  }

  async loadWithFallback(modelName, primaryUrl, fallbackUrl) {
    try {
      const model = await tf.loadLayersModel(primaryUrl);
      this.models.set(modelName, model);
    } catch (primaryError) {
      console.warn(`Primary model failed, trying fallback`, primaryError);
      try {
        const fallbackModel = await tf.loadLayersModel(fallbackUrl);
        this.models.set(modelName, fallbackModel);
        this.fallbacks.set(modelName, true);
      } catch (fallbackError) {
        console.error(`Both primary and fallback models failed`, fallbackError);
        throw new Error(`Failed to load model: ${modelName}`);
      }
    }
  }

  async runModel(modelName, input) {
    const model = this.models.get(modelName);
    if (!model) {
      throw new Error(`Model ${modelName} not loaded`);
    }

    try {
      const result = await model.predict(input);
      return result;
    } catch (error) {
      console.error(`Model execution failed: ${modelName}`, error);
      // Implement retry logic or use cached results
      throw error;
    }
  }
}

Testing and Debugging

1. Unit Testing

const tf = require('@tensorflow/tfjs-node');

describe('Neural Network Tests', () => {
  let model;

  beforeEach(() => {
    model = tf.sequential();
    model.add(tf.layers.dense({inputShape: [4], units: 10}));
    model.add(tf.layers.dense({units: 1}));
    model.compile({optimizer: 'adam', loss: 'meanSquaredError'});
  });

  test('model makes predictions', async () => {
    const input = tf.tensor([[1, 2, 3, 4]]);
    const prediction = model.predict(input);

    expect(prediction.shape).toEqual([1, 1]);
    expect(typeof prediction.dataSync()[0]).toBe('number');

    input.dispose();
    prediction.dispose();
  });

  test('model trains successfully', async () => {
    const xs = tf.randomNormal([100, 4]);
    const ys = tf.randomNormal([100, 1]);

    const history = await model.fit(xs, ys, {
      epochs: 1,
      verbose: 0
    });

    expect(history.history.loss.length).toBe(1);

    xs.dispose();
    ys.dispose();
  });
});

2. Performance Monitoring

class PerformanceMonitor {
  static timings = new Map();

  static start(label) {
    this.timings.set(label, performance.now());
  }

  static end(label) {
    const start = this.timings.get(label);
    if (start) {
      const duration = performance.now() - start;
      console.log(`${label}: ${duration.toFixed(2)}ms`);
      this.timings.delete(label);
      return duration;
    }
    return 0;
  }

  static async measureAsync(label, asyncFn) {
    this.start(label);
    const result = await asyncFn();
    this.end(label);
    return result;
  }
}

// Usage
const result = await PerformanceMonitor.measureAsync('model prediction', async () => {
  return await model.predict(input);
});

Advanced Topics

1. Transfer Learning

async function createTransferModel(baseModel, newOutputUnits) {
  // Freeze the base model layers so only the new head is trained
  for (const layer of baseModel.layers) {
    layer.trainable = false;
  }

  // Re-use the base model up to its penultimate layer as a feature extractor
  const featureExtractor = tf.model({
    inputs: baseModel.inputs,
    outputs: baseModel.layers[baseModel.layers.length - 2].output
  });

  // Add a new classification head on top of the frozen convolutional features
  const transferModel = tf.sequential({
    layers: [
      featureExtractor,
      tf.layers.globalAveragePooling2d({}),  // assumes 4D conv feature maps
      tf.layers.dense({units: 128, activation: 'relu'}),
      tf.layers.dropout({rate: 0.5}),
      tf.layers.dense({units: newOutputUnits, activation: 'softmax'})
    ]
  });

  return transferModel;
}

2. Federated Learning

class FederatedLearner {
  constructor(model) {
    this.model = model;
    this.clientModels = [];
  }

  addClientModel(clientModel) {
    this.clientModels.push(clientModel);
  }

  async aggregateModels() {
    // Simple averaging of model weights
    const aggregatedWeights = [];

    for (let i = 0; i < this.model.weights.length; i++) {
      const layerWeights = this.clientModels.map(client =>
        client.weights[i].read()
      );

      // Average the weights
      const avgWeight = tf.tidy(() => {
        const stacked = tf.stack(layerWeights);
        return stacked.mean(0);
      });

      aggregatedWeights.push(avgWeight);
    }

    // Update global model
    for (let i = 0; i < aggregatedWeights.length; i++) {
      this.model.weights[i].write(aggregatedWeights[i]);
      aggregatedWeights[i].dispose();
    }
  }
}

3. Model Compression

// Note: pruneModel, quantizeModel, distillModel, and teacherModel below are
// placeholder helpers sketching the pipeline; they are not built-in
// TensorFlow.js APIs and would need to be implemented separately.
async function compressModel(model) {
  // Prune low-weight connections
  const prunedModel = await pruneModel(model, 0.1); // Remove 10% of weights

  // Quantize weights
  const quantizedModel = await quantizeModel(prunedModel, 'uint8');

  // Apply knowledge distillation
  const distilledModel = await distillModel(quantizedModel, teacherModel);

  return distilledModel;
}

Future of TensorFlow.js

1. Emerging Features

  • Web Neural Network API: Hardware-accelerated inference
  • WebGPU: Next-generation graphics API support (see the sketch after this list)
  • WebAssembly SIMD: Vectorized operations
  • Edge Computing: On-device ML inference
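
The WebGPU backend is already available as an experimental package. A minimal sketch of opting into it (assuming @tensorflow/tfjs-backend-webgpu is installed and the browser exposes navigator.gpu):

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu'; // registers the experimental 'webgpu' backend

if (navigator.gpu) {
  await tf.setBackend('webgpu');
} else {
  await tf.setBackend('webgl'); // fall back when WebGPU is unavailable
}
await tf.ready();
console.log('Active backend:', tf.getBackend());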

2. Industry Adoption

  • Smart Cameras: Real-time object detection
  • Voice Assistants: On-device speech recognition
  • Medical Imaging: Client-side image analysis
  • Augmented Reality: Real-time pose estimation
  • Autonomous Vehicles: Sensor data processing

3. Performance Improvements

  • Better Memory Management: Automatic tensor disposal
  • Graph Optimization: Improved execution planning
  • Hardware Acceleration: Better GPU utilization
  • Model Quantization: Smaller, faster models

Best Practices

1. Development Workflow

  • Start Simple: Begin with pre-trained models
  • Iterate Quickly: Use tfjs-vis for debugging and visualization (see the sketch after this list)
  • Test Thoroughly: Unit tests for all ML logic
  • Monitor Performance: Track memory usage and inference time
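
For the tfjs-vis suggestion above, a minimal sketch of live training charts (assuming @tensorflow/tfjs-vis is installed and the model was compiled with an accuracy metric):

import * as tfvis from '@tensorflow/tfjs-vis';

const surface = {name: 'Training metrics', tab: 'Model'};
await model.fit(xs, ys, {
  epochs: 20,
  // Plots loss and accuracy in the tfjs-vis visor after every epoch
  callbacks: tfvis.show.fitCallbacks(surface, ['loss', 'acc'], {
    callbacks: ['onEpochEnd']
  })
});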

2. Production Considerations

  • Model Size: Compress models for web deployment
  • Fallback Strategies: Graceful degradation when ML fails
  • Privacy: Keep user data on-device
  • Accessibility: Provide alternatives to ML features

3. Learning Resources

  • Official Documentation: tensorflow.org/js
  • TensorFlow.js Examples: github.com/tensorflow/tfjs-examples
  • Made with TensorFlow.js: Showcases of real applications
  • TensorFlow Hub: Pre-trained models repository

Conclusion

TensorFlow.js brings the power of machine learning directly to the browser, enabling intelligent web applications that respect user privacy and work offline. From simple image classification to complex natural language processing, the possibilities are vast.

Key Takeaways:

  • Browser-first ML: Run models entirely in the browser
  • Privacy-focused: Data stays on the user's device
  • JavaScript ecosystem: Leverage existing web skills
  • Progressive enhancement: ML as enhancement, not requirement
  • Performance optimization: Choose the right backend and optimize models

The future of web development includes intelligent applications that can see, hear, understand, and learn. TensorFlow.js makes this future accessible to every JavaScript developer. Start experimenting with pre-trained models, gradually build your understanding, and create the next generation of intelligent web applications.

About the author

Rafael De Paz

Full Stack Developer

Passionate full-stack developer specializing in building high-quality web applications and responsive sites. Expert in robust data handling, leveraging modern frameworks, cloud technologies, and AI tools to deliver scalable, high-performance solutions that drive user engagement and business growth. I harness AI technologies to accelerate development, testing, and debugging workflows.
