165 lines
5.6 KiB
TypeScript
165 lines
5.6 KiB
TypeScript
import React, { useState } from "react";
|
||
import { Sliders } from "lucide-react";
|
||
|
||
/** Rendered size, in CSS pixels, of one matrix cell. */
const CELL_PX = 6;

/**
 * Deterministic pseudo-random cell opacity in [0.2, 0.8).
 *
 * A pure function of (row, col, seed) is used instead of `Math.random()` so
 * that cells keep their shade across re-renders (every slider move re-renders
 * the whole figure) and so server- and client-rendered markup agree.
 * Keying on (row, col) rather than the flat cell index keeps existing cells
 * stable when the rank changes the number of columns.
 */
const cellOpacity = (row: number, col: number, seed: number): number => {
  // Small 32-bit integer mix of the three inputs.
  let h =
    Math.imul(row + 1, 0x9e3779b1) ^
    Math.imul(col + 1, 0x85ebca6b) ^
    Math.imul(seed + 1, 0xc2b2ae35);
  h ^= h >>> 16;
  h = Math.imul(h, 0x7feb352d);
  h ^= h >>> 15;
  // Map the unsigned 32-bit value to [0, 1), then into the [0.2, 0.8) band.
  return ((h >>> 0) / 4294967296) * 0.6 + 0.2;
};

interface MatrixGridProps {
  /** Number of cell rows to draw. */
  rows: number;
  /** Number of cell columns to draw. */
  cols: number;
  /** CSS colour applied to every cell; shade varies via opacity. */
  color: string;
  /** Dimension caption rendered beneath the grid. */
  label: string;
  /** Varies the shading pattern so different matrices don't look identical. */
  seed: number;
}

/** Draws a `rows × cols` grid of tinted cells with a caption underneath. */
const MatrixGrid: React.FC<MatrixGridProps> = ({
  rows,
  cols,
  color,
  label,
  seed,
}) => (
  <div
    className="relative transition-all duration-300 ease-in-out"
    style={{
      display: "grid",
      gridTemplateColumns: `repeat(${cols}, minmax(0, 1fr))`,
      gridTemplateRows: `repeat(${rows}, minmax(0, 1fr))`,
      width: `${cols * CELL_PX}px`,
      height: `${rows * CELL_PX}px`,
      border: "1px solid #e5e7eb",
      backgroundColor: "#f9fafb",
    }}
  >
    {Array.from({ length: rows * cols }, (_, i) => (
      <div
        key={i}
        style={{
          backgroundColor: color,
          // Cells fill row-major, so index i sits at (i / cols, i % cols).
          opacity: cellOpacity(Math.floor(i / cols), i % cols, seed),
        }}
      />
    ))}

    {/* Matrix label overlay */}
    <div className="absolute -bottom-6 left-0 w-full text-center text-[10px] font-mono text-gray-500">
      {label}
    </div>
  </div>
);

/**
 * Interactive illustration of LoRA's low-rank update: ΔW (d × k) is shown as
 * the product of B (d × r) and A (r × k), with a slider for the rank r and a
 * dashboard comparing trainable-parameter counts against a full update.
 */
export const MatrixViz: React.FC = () => {
  const [rank, setRank] = useState(4);

  // Cells drawn along each side of the on-screen matrices. The drawing is a
  // scaled-down stand-in; the numbers displayed come from realD / realK.
  const d = 16; // hidden dimension (rows of B and ΔW)
  const k = 16; // output dimension (columns of A and ΔW)

  // The rank cannot meaningfully exceed the smaller drawn dimension.
  const maxRank = Math.min(d, k);

  // Displayed sizes assume a 4096 × 4096 projection, e.g. a model with
  // d_model = 4096 such as Llama-7B.
  const realD = 4096;
  const realK = 4096;

  const paramsFull = realD * realK;
  const paramsLoRA = realD * rank + rank * realK;
  // Percentage fewer trainable parameters than a full update of this matrix.
  const reduction = ((1 - paramsLoRA / paramsFull) * 100).toFixed(2);

  return (
    <div className="my-12 font-sans bg-gray-50 border border-gray-200 rounded-lg p-8">
      <div className="flex flex-col md:flex-row items-start md:items-center justify-between mb-8 gap-6">
        <div>
          <h4 className="font-bold text-sm uppercase tracking-widest text-ink flex items-center gap-2">
            <Sliders size={16} />
            Interactive Decomposition
          </h4>
          <p className="text-xs text-subtle mt-1 max-w-md">
            Adjust the Rank (r) to see how LoRA decomposes the weight update
            matrix into two smaller dense matrices.
          </p>
        </div>

        <div className="flex items-center gap-4 bg-white p-3 rounded-md border border-gray-200 shadow-sm">
          <label
            htmlFor="rank-slider"
            className="text-xs font-bold text-ink uppercase"
          >
            Rank (r): {rank}
          </label>
          <input
            id="rank-slider"
            type="range"
            min={1}
            max={maxRank}
            step={1}
            value={rank}
            onChange={(e) => setRank(Number(e.target.value))}
            className="w-32 h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer accent-accent"
          />
        </div>
      </div>

      <div className="flex flex-wrap items-center justify-center gap-4 md:gap-12 overflow-x-auto py-4 min-h-[200px]">
        {/* Matrix B (d x r) */}
        <div className="flex flex-col items-center group">
          <div className="mb-2 text-xs font-mono text-blue-600 font-bold">
            B
          </div>
          <MatrixGrid
            rows={d}
            cols={rank}
            color="#93c5fd"
            label={`${realD} × ${rank}`}
            seed={1}
          />
        </div>

        <div className="text-lg text-gray-400 font-light">×</div>

        {/* Matrix A (r x k) */}
        <div className="flex flex-col items-center group">
          <div className="mb-2 text-xs font-mono text-blue-700 font-bold">
            A
          </div>
          <MatrixGrid
            rows={rank}
            cols={k}
            color="#2563eb"
            label={`${rank} × ${realK}`}
            seed={2}
          />
        </div>

        {/* Arrow */}
        <div className="text-lg text-gray-400 font-light">⟶</div>

        {/* Matrix Delta W (d x k) */}
        <div className="flex flex-col items-center opacity-50 grayscale">
          <div className="mb-2 text-xs font-mono text-green-600 font-bold">
            ΔW
          </div>
          <MatrixGrid
            rows={d}
            cols={k}
            color="#4ade80"
            label={`${realD} × ${realK}`}
            seed={3}
          />
        </div>
      </div>

      {/* Math Dashboard */}
      <div className="grid grid-cols-1 md:grid-cols-3 gap-4 mt-8 border-t border-gray-200 pt-6">
        <div className="bg-white p-4 rounded border border-gray-100">
          <div className="text-[10px] uppercase tracking-widest text-subtle mb-1">
            Trainable Params (LoRA)
          </div>
          <div className="font-mono text-xl text-blue-600 font-bold">
            {paramsLoRA.toLocaleString()}
          </div>
          <div className="text-[10px] text-gray-400 font-mono mt-1">
            ({realD}×{rank}) + ({rank}×{realK})
          </div>
        </div>

        <div className="bg-white p-4 rounded border border-gray-100">
          <div className="text-[10px] uppercase tracking-widest text-subtle mb-1">
            Trainable Params (Full)
          </div>
          <div className="font-mono text-xl text-gray-400">
            {paramsFull.toLocaleString()}
          </div>
          <div className="text-[10px] text-gray-400 font-mono mt-1">
            {realD} × {realK}
          </div>
        </div>

        <div className="bg-green-50 p-4 rounded border border-green-100">
          <div className="text-[10px] uppercase tracking-widest text-green-800 mb-1">
            Memory Reduction
          </div>
          <div className="font-mono text-xl text-green-600 font-bold">
            {reduction}%
          </div>
          <div className="text-[10px] text-green-700 font-mono mt-1">
            More VRAM for batch size
          </div>
        </div>
      </div>
    </div>
  );
};