In [ ]:
import os
import time
import shutil
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import cv2
import ee
import geemap
from google.colab import drive
from sklearn.model_selection import train_test_split

try:
    import geedim
except ImportError:
    import subprocess
    subprocess.check_call(["pip", "install", "geedim"])
    subprocess.check_call(["pip", "install", "geemap"])

# 1. SETUP & AUTHENTICATION (Robust Fix)
drive.mount('/content/drive', force_remount=True)

def initialize_ee():
    try:
        ee.Initialize(project='[REDACTED_FOR_SECURITY]')
        print("Earth Engine Initialized Successfully.")
    except Exception as e:
        print(f"Initialization failed. Triggering Authentication... ({e})")
        ee.Authenticate()
        ee.Initialize(project='[REDACTED_FOR_SECURITY]')
        print("Earth Engine Authenticated & Initialized.")

initialize_ee()

ASSET_ID = 'projects/[REDACTED_FOR_SECURITY]/assets/Punjab_Mask_2024_NEW'
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

PATCH_SIZE = 224
S2_SCALE = 5000.0

TIME_WINDOWS = [
    ('2024-10-15', '2024-11-15'),  # T1
    ('2025-01-01', '2025-01-31'),  # T2
    ('2025-02-15', '2025-03-15')   # T3
]

def generate_prithvi_npy():
    print(f"Starting Data Generation. Project: [REDACTED_FOR_SECURITY]")

    mask_img = ee.Image(ASSET_ID)
    roi_geom = mask_img.geometry()
    mask_file = 'local_mask_prithvi.tif'

    if not os.path.exists(mask_file):
        print("Downloading Mask...")
        geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)

    with rasterio.open(mask_file) as src:
        b = src.bounds
        cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
        offset = 0.04
        window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)
        mask = src.read(1, window=window)
        mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
        target_h, target_w = mask.shape
        small_roi = ee.Geometry.BBox(cx-offset, cy-offset, cx+offset, cy+offset)
        print(f"ROI: {target_h}x{target_w}")

    stack = []
    for i, (start, end) in enumerate(TIME_WINDOWS):
        fname = f'prithvi_time_{i}.tif'
        attempts = 0
        while not os.path.exists(fname) and attempts < 3:
            try:
                print(f"Downloading T{i+1}: {start} to {end}...")
                img = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \
                    .filterBounds(small_roi).filterDate(start, end).median() \
                    .select(['B2','B3','B4','B8','B11','B12'])
                geemap.download_ee_image(img, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
            except:
                attempts += 1
                time.sleep(2)

        if not os.path.exists(fname):
            if i > 0: shutil.copy(f'prithvi_time_{i-1}.tif', fname)
            else:
                with rasterio.open(mask_file) as src:
                     p = src.profile
                     p.update(count=6, dtype=rasterio.float32)
                     with rasterio.open(fname, 'w', **p) as dst: dst.write(np.zeros((6, target_h, target_w), dtype=np.float32))

        with rasterio.open(fname) as src:
            arr = src.read()
            arr = np.transpose(arr, (1, 2, 0))
            if arr.shape[:2] != (target_h, target_w):
                arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            arr = np.clip(arr / S2_SCALE, 0, 1).astype(np.float32)
            stack.append(arr)

    full_cube = np.stack(stack, axis=2)
    x_out, y_out = [], []
    stride = PATCH_SIZE

    print("Tiling...")
    for y in range(0, target_h, stride):
        for x in range(0, target_w, stride):
            img_p = full_cube[y:y+stride, x:x+stride]
            mask_p = mask[y:y+stride, x:x+stride]
            if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
            if np.mean(mask_p) < 0.01: continue
            if np.isnan(img_p).any(): continue
            x_out.append(img_p)
            y_out.append(mask_p)

    X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
    y = np.array(y_out, dtype=np.float32)[:, None, :, :]

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
    np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
    np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
    np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
    print("Data Generation Complete.")

generate_prithvi_npy()
Mounted at /content/drive
Initialization failed. Triggering Authentication... (Please authorize access to your Earth Engine account by running

earthengine authenticate

in your command line, or ee.Authenticate() in Python, and then retry.)
Earth Engine Authenticated & Initialized.
Starting Data Generation. Project: satmae-2026
Downloading Mask...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release.  Please use the 'ee.Image.gd' accessor instead.
  img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW:   0%|          |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 0.82 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.00 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.81 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.12 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.85 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 1.71 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 1.30 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.45 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'.
  return STACClient().get(self.id)
ROI: 891x891
Downloading T1: 2024-10-15 to 2024-11-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'.
  return STACClient().get(self.id)
Downloading T2: 2025-01-01 to 2025-01-31...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading T3: 2025-02-15 to 2025-03-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Tiling...
Data Generation Complete.
In [ ]:
!pip install geedim
!pip install geemap
!pip install segmentation-models-pytorch
Requirement already satisfied: geedim in /usr/local/lib/python3.12/dist-packages (2.0.0)
Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.12/dist-packages (from geedim) (2.0.2)
Requirement already satisfied: rasterio>=1.3.8 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.0)
Requirement already satisfied: click>=8 in /usr/local/lib/python3.12/dist-packages (from geedim) (8.3.1)
Requirement already satisfied: tqdm>=4.6 in /usr/local/lib/python3.12/dist-packages (from geedim) (4.67.1)
Requirement already satisfied: earthengine-api>=0.1.379 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.24)
Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.12/dist-packages (from geedim) (0.9.0)
Requirement already satisfied: fsspec>=2025.2 in /usr/local/lib/python3.12/dist-packages (from geedim) (2025.3.0)
Requirement already satisfied: aiohttp>=3.11 in /usr/local/lib/python3.12/dist-packages (from geedim) (3.13.3)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.4.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (25.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (0.4.1)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.22.0)
Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (3.7.0)
Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.187.0)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.43.0)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.3.0)
Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.31.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.32.4)
Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2.4.0)
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2026.1.4)
Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (0.7.2)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (3.3.1)
Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp>=3.11->geedim) (4.15.0)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (2.29.0)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (4.2.0)
Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (6.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.4.2)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (4.9.1)
Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.12/dist-packages (from yarl<2.0,>=1.17.0->aiohttp>=3.11->geedim) (3.11)
Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.5.0)
Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.8.0)
Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (1.8.0)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (3.4.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (2.5.0)
Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.72.0)
Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (5.29.5)
Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.27.0)
Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.6.1)
Requirement already satisfied: geemap in /usr/local/lib/python3.12/dist-packages (0.35.3)
Requirement already satisfied: bqplot in /usr/local/lib/python3.12/dist-packages (from geemap) (0.12.45)
Requirement already satisfied: colour in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.5)
Requirement already satisfied: earthengine-api>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (1.5.24)
Requirement already satisfied: eerepr>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.2)
Requirement already satisfied: folium>=0.17.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0)
Requirement already satisfied: geocoder in /usr/local/lib/python3.12/dist-packages (from geemap) (1.38.1)
Requirement already satisfied: ipyevents in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.4)
Requirement already satisfied: ipyfilechooser>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.6.0)
Requirement already satisfied: ipyleaflet>=0.19.2 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0)
Requirement already satisfied: ipytree in /usr/local/lib/python3.12/dist-packages (from geemap) (0.2.2)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from geemap) (3.10.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from geemap) (2.2.2)
Requirement already satisfied: plotly in /usr/local/lib/python3.12/dist-packages (from geemap) (5.24.1)
Requirement already satisfied: pyperclip in /usr/local/lib/python3.12/dist-packages (from geemap) (1.11.0)
Requirement already satisfied: pyshp>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from geemap) (3.0.3)
Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from geemap) (7.3.2)
Requirement already satisfied: scooby in /usr/local/lib/python3.12/dist-packages (from geemap) (0.11.0)
Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (3.7.0)
Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.187.0)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.43.0)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.3.0)
Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.31.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.32.4)
Requirement already satisfied: branca>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (0.8.2)
Requirement already satisfied: jinja2>=2.9 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (3.1.6)
Requirement already satisfied: xyzservices in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (2025.11.0)
Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (from ipyfilechooser>=0.6.0->geemap) (7.7.1)
Requirement already satisfied: jupyter-leaflet<0.21,>=0.20 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.20.0)
Requirement already satisfied: traittypes<3,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.2.3)
Requirement already satisfied: traitlets>=4.3.0 in /usr/local/lib/python3.12/dist-packages (from bqplot->geemap) (5.7.1)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.3)
Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (8.3.1)
Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.0.0)
Requirement already satisfied: ratelim in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (0.1.6)
Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.17.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (4.61.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.4.9)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (25.0)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (11.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (3.3.1)
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.12/dist-packages (from plotly->geemap) (9.1.2)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (2.29.0)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (4.2.0)
Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (6.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.4.2)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (4.9.1)
Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.17.1)
Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0)
Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.6.10)
Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.34.0)
Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.16)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2>=2.9->folium>=0.17.0->geemap) (3.0.3)
Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.5.0)
Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.8.0)
Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (1.8.0)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2026.1.4)
Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ratelim->geocoder->geemap) (4.4.2)
Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.72.0)
Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (5.29.5)
Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.27.0)
Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.8.15)
Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.4.9)
Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.1)
Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.6.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.5)
Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (26.2.1)
Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.1)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (75.2.0)
Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.19.2)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.5)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.52)
Requirement already satisfied: pygments in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.19.2)
Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.9.0)
Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.6.1)
Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.7)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.8.5)
Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.4)
Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.1)
Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0)
Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.10.4)
Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.16.6)
Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0)
Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.18.1)
Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.23.1)
Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.3)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.14)
Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.5.1)
Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.4)
Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.13.5)
Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.3.0)
Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.1)
Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.3.0)
Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.2.0)
Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.10.4)
Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1)
Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.21.2)
Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.26.0)
Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0)
Requirement already satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.1)
Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0)
Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.4.0)
Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2025.9.1)
Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.37.0)
Requirement already satisfied: rpds-py>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.30.0)
Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.14.0)
Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0)
Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.8.1)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.15.0)
Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.23)
Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.12.1)
Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.12.0)
Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.3)
Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.7.0)
Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.9.0)
Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.0.0)
Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.0.3)
Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.4)
Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.1)
Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1)
Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (20.11.0)
Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.0)
Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.1.0)
Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.0)
Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.10.0)
Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.1)
Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0)
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Requirement already satisfied: huggingface-hub>=0.24 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.36.0)
Requirement already satisfied: numpy>=1.19.3 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.0.2)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (11.3.0)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.7.0)
Requirement already satisfied: timm>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (1.0.24)
Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.9.0+cu126)
Requirement already satisfied: torchvision>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.24.0+cu126)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (4.67.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (3.20.2)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2025.3.0)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (25.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (6.0.3)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2.32.4)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (4.15.0)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (1.2.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (75.2.0)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.6.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.1.6)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.3.20)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.11.1.6)
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.5.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8->segmentation-models-pytorch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8->segmentation-models-pytorch) (3.0.3)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2026.1.4)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 6.4 MB/s eta 0:00:00
Installing collected packages: segmentation-models-pytorch
Successfully installed segmentation-models-pytorch-0.5.0
In [ ]:
import os
import requests
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

SAVE_PATH = "/content/drive/MyDrive/Prithvi_100M.pt"
DOWNLOAD_URL = "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_100M.pt"

if not os.path.exists(SAVE_PATH):
    print(f"Downloading weights to: {SAVE_PATH}")
    try:
        response = requests.get(DOWNLOAD_URL, stream=True)
        response.raise_for_status()
        with open(SAVE_PATH, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Download Complete.")
    except Exception as e:
        print(f"Download Failed: {e}")
else:
    print(f"Weights exist at: {SAVE_PATH}")
Mounted at /content/drive
Downloading weights to: /content/drive/MyDrive/Prithvi_100M.pt
Download Complete.
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
import segmentation_models_pytorch as smp

def apply_training_augmentation(x, y):
    if np.random.rand() > 0.5:
        x = torch.flip(x, [4])
        y = torch.flip(y, [3])
    if np.random.rand() > 0.5:
        x = torch.flip(x, [3])
        y = torch.flip(y, [2])
    k = np.random.randint(0, 4)
    x = torch.rot90(x, k, [3, 4])
    y = torch.rot90(y, k, [2, 3])
    if np.random.rand() > 0.5:
        noise = (torch.rand(x.shape[0], 1, 1, 1, 1, device=x.device) * 0.2) + 0.9
        x = x * noise
    return x.contiguous(), y.contiguous()

class TestTimeAugmentation(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        pred_orig = self.model(x)
        x_hflip = torch.flip(x, [4])
        pred_hflip = torch.flip(self.model(x_hflip), [3])
        x_vflip = torch.flip(x, [3])
        pred_vflip = torch.flip(self.model(x_vflip), [2])
        x_rot = torch.rot90(x, 1, [3, 4])
        pred_rot = torch.rot90(self.model(x_rot), -1, [2, 3])
        return torch.stack([pred_orig, pred_hflip, pred_vflip, pred_rot]).mean(dim=0)

def validate_with_tta(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    tta_model = TestTimeAugmentation(model)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = tta_model(x)
            loss = criterion(preds, y)
            total_loss += loss.item()
    return total_loss / len(loader)

class HausdorffDTLoss(nn.Module):
    def __init__(self, alpha=2.0):
        super().__init__()
        self.alpha = alpha
    def forward(self, pred, gt):
        with torch.no_grad():
            gt_np = gt.cpu().numpy()
            dist_map = np.zeros_like(gt_np)
            for i in range(len(gt_np)):
                mask = (gt_np[i, 0] > 0.5).astype(np.uint8)
                if mask.sum() == 0:
                    dist_map[i, 0] = np.ones_like(mask) * 100.0
                    continue
                d_in = distance(mask)
                d_out = distance(1 - mask)
                dist_map[i, 0] = (d_out - d_in)
            dist_map = torch.tensor(dist_map, device=pred.device, dtype=torch.float32)
        probs = torch.sigmoid(pred)
        return torch.mean((probs - gt) ** 2 * (1 + self.alpha * torch.abs(dist_map)))

class CompoundLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.dice = smp.losses.DiceLoss(mode='binary', from_logits=True)
        self.hausdorff = HausdorffDTLoss(alpha=2.0)
    def forward(self, p, t):
        return 0.7 * self.dice(p, t) + 0.3 * self.hausdorff(p, t)
In [ ]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

# CONFIGURATION (For Model Init)
CONFIG = {
    "EPOCHS": 5000,
    "PATIENCE": 200,
    "BATCH_SIZE": 8,
    "LEARNING_RATE": 1e-4,
    "WEIGHT_DECAY": 0.05,
    "WARMUP_EPOCHS": 20,
    "SWA_START_EPOCH": None,
    "RESUME": True,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "IMG_SIZE": 224,
    "NUM_FRAMES": 3,
    "IN_CHANS": 6,
    "EMBED_DIM": 768,
    "DEPTH": 12,
    "NUM_HEADS": 12,
    "PATCH_SIZE": 16,
    "PRETRAINED_PATH": "/content/drive/MyDrive/Prithvi_100M.pt",
    "SAVE_DIR": SAVE_DIR
}

class PunjabWheatDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.masks = np.load(y_path, mmap_mode='r')
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        img = self.data[idx]
        mask = self.masks[idx]
        return torch.from_numpy(img.copy()).float(), torch.from_numpy(mask.copy()).float()

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size):
    grid_h = np.arange(grid_size, dtype=np.float32)
    pos_embed_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_h)
    grid_w = np.arange(grid_size, dtype=np.float32)
    pos_embed_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_w)
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)

    pos_t = pos_embed_t[:, np.newaxis, np.newaxis, :]
    pos_h = pos_embed_h[np.newaxis, :, np.newaxis, :]
    pos_w = pos_embed_w[np.newaxis, np.newaxis, :, :]

    pos_embed = pos_t + pos_h + pos_w
    pos_embed = pos_embed.reshape(-1, embed_dim)
    return torch.from_numpy(pos_embed).float().unsqueeze(0)

class PatchEmbed3D(nn.Module):
    def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class PrithviPartialFT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbed3D(patch_size=CONFIG["PATCH_SIZE"], frames=CONFIG["NUM_FRAMES"], in_chans=CONFIG["IN_CHANS"], embed_dim=CONFIG["EMBED_DIM"])
        self.register_buffer("pos_embed", get_3d_sincos_pos_embed(CONFIG["EMBED_DIM"], CONFIG["IMG_SIZE"] // CONFIG["PATCH_SIZE"], CONFIG["NUM_FRAMES"]))

        encoder_layer = nn.TransformerEncoderLayer(d_model=CONFIG["EMBED_DIM"], nhead=CONFIG["NUM_HEADS"], dim_feedforward=CONFIG["EMBED_DIM"]*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=CONFIG["DEPTH"])

        in_dim = CONFIG["EMBED_DIM"] * CONFIG["NUM_FRAMES"]
        self.decoder = nn.Sequential(
            nn.Conv2d(in_dim, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 1, kernel_size=1)
        )
        self._init_decoder()

    def _init_decoder(self):
        for m in self.decoder.modules():
            if isinstance(m, (nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.constant_(m.bias, 0)
        print("Decoder initialized.")

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.encoder(x)
        H_p = CONFIG["IMG_SIZE"] // CONFIG["PATCH_SIZE"]
        x = x.transpose(1, 2).view(B, CONFIG["EMBED_DIM"], CONFIG["NUM_FRAMES"], H_p, H_p)
        x = x.reshape(B, -1, H_p, H_p)
        return self.decoder(x)
In [6]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

train_x = os.path.join(SAVE_DIR, 'train_x.npy')
train_y = os.path.join(SAVE_DIR, 'train_y.npy')
val_x = os.path.join(SAVE_DIR, 'val_x.npy')
val_y = os.path.join(SAVE_DIR, 'val_y.npy')

if not os.path.exists(train_x):
    raise FileNotFoundError(f"Data not found in {SAVE_DIR}. Run Cell 1 first.")

print("Loading Data...")
train_loader = DataLoader(PunjabWheatDataset(train_x, train_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)
val_loader = DataLoader(PunjabWheatDataset(val_x, val_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=2)

model = PrithviPartialFT().to(CONFIG["DEVICE"])

# --- WEIGHT LOADING ---
if os.path.exists(CONFIG["PRETRAINED_PATH"]):
    print(f"Loading Weights from {CONFIG['PRETRAINED_PATH']}...")
    ckpt = torch.load(CONFIG["PRETRAINED_PATH"], map_location='cpu')
    state_dict = ckpt['model'] if 'model' in ckpt else ckpt
    new_state_dict = {}
    for k, v in state_dict.items():
        if 'blocks' in k:
            new_key = k.replace('blocks', 'encoder.layers')
            new_state_dict[new_key] = v
        elif 'patch_embed' in k: new_state_dict[k] = v
        elif 'pos_embed' in k:
            if v.shape == model.pos_embed.shape: new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)
    print("Pretrained Weights Loaded.")
else:
    print("WARNING: PRETRAINED WEIGHTS NOT FOUND.")

# --- PARTIAL FREEZING STRATEGY (10 Frozen / 2 Unfrozen) ---
print("\nConfiguring Layers...")

# 1. Train Embeddings
for param in model.patch_embed.parameters(): param.requires_grad = True
print("- Embeddings: LEARNABLE")

# 2. Configure Encoder (0-9 Frozen, 10-11 Trainable)
for i, layer in enumerate(model.encoder.layers):
    if i < 10:
        for param in layer.parameters(): param.requires_grad = False
    else:
        for param in layer.parameters(): param.requires_grad = True
print(f"- Encoder Layers 0-9: FROZEN")
print(f"- Encoder Layers 10-11: LEARNABLE")

# 3. Train Decoder
for param in model.decoder.parameters(): param.requires_grad = True
print("- Decoder: LEARNABLE\n")

criterion = CompoundLoss().to(CONFIG["DEVICE"])

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.AdamW(trainable_params, lr=CONFIG["LEARNING_RATE"], weight_decay=CONFIG["WEIGHT_DECAY"])

scheduler_warmup = LinearLR(optimizer, start_factor=0.01, total_iters=CONFIG["WARMUP_EPOCHS"])
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=CONFIG["EPOCHS"] - CONFIG["WARMUP_EPOCHS"], eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[CONFIG["WARMUP_EPOCHS"]])

swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
swa_active = False

checkpoint_path = os.path.join(CONFIG["SAVE_DIR"], "checkpoint.pth")
start_epoch = 0
best_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': []}

if CONFIG["RESUME"] and os.path.exists(checkpoint_path):
    print("Found checkpoint. Loading...")
    try:
        ckpt = torch.load(checkpoint_path)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        best_loss = ckpt.get('best_loss', float('inf'))
        history = ckpt.get('history', {'train_loss': [], 'val_loss': []})
        if 'swa_model' in ckpt: swa_model.load_state_dict(ckpt['swa_model'])
        if 'swa_active' in ckpt: swa_active = ckpt['swa_active']
        print(f"Resumed from Epoch {start_epoch}")
    except:
        start_epoch = 0

print("Starting Partial Fine-Tuning...")
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
    model.train()
    train_loss = 0

    for x, y in train_loader:
        x, y = x.to(CONFIG["DEVICE"]), y.to(CONFIG["DEVICE"])
        x, y = apply_training_augmentation(x, y)

        optimizer.zero_grad()
        preds = model(x)
        loss = criterion(preds, y.contiguous())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = validate_with_tta(model, val_loader, criterion, CONFIG["DEVICE"])

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)

    SWA_TRIGGER = CONFIG["PATIENCE"] // 2
    if patience_counter >= SWA_TRIGGER and not swa_active:
        print(f"Stagnation (Patience {patience_counter}). Activating SWA.")
        swa_active = True
        swa_model.update_parameters(model)

    if swa_active:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

    checkpoint_dict = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'swa_model': swa_model.state_dict(),
        'history': history,
        'swa_active': swa_active
    }
    torch.save(checkpoint_dict, checkpoint_path)

    if epoch % 10 == 0:
        torch.save(checkpoint_dict, os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth"))

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "best_model.pth"))
        patience_counter = 0
        print(f"New Best Model! Loss: {best_loss:.4f}")
    else:
        patience_counter += 1

    if patience_counter >= CONFIG["PATIENCE"]:
        print(f"Early Stopping at Epoch {epoch}")
        if swa_active:
            torch.save(swa_model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "swa_model_final.pth"))
        break

    if epoch % 5 == 0:
        mode = "SWA" if swa_active else "STD"
        print(f"Ep {epoch} [{mode}] | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
Loading Data...
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
Decoder initialized.
Loading Weights from /content/drive/MyDrive/Prithvi_100M.pt...
Pretrained Weights Loaded.

Configuring Layers...
- Embeddings: LEARNABLE
- Encoder Layers 0-9: FROZEN
- Encoder Layers 10-11: LEARNABLE
- Decoder: LEARNABLE

Starting Partial Fine-Tuning...
New Best Model! Loss: 0.9044
Ep 0 [STD] | Train: 1.2669 | Val: 0.9044
New Best Model! Loss: 0.8892
New Best Model! Loss: 0.8671
New Best Model! Loss: 0.8527
Ep 5 [STD] | Train: 0.8977 | Val: 0.8527
New Best Model! Loss: 0.8485
Ep 10 [STD] | Train: 0.7293 | Val: 0.8817
Ep 15 [STD] | Train: 0.6343 | Val: 0.8854
Ep 20 [STD] | Train: 0.5313 | Val: 0.8954
Ep 25 [STD] | Train: 0.4535 | Val: 0.8913
New Best Model! Loss: 0.7621
Ep 30 [STD] | Train: 0.4097 | Val: 0.7621
New Best Model! Loss: 0.6572
New Best Model! Loss: 0.5596
New Best Model! Loss: 0.4836
New Best Model! Loss: 0.4594
New Best Model! Loss: 0.4345
Ep 35 [STD] | Train: 0.3590 | Val: 0.4345
New Best Model! Loss: 0.4039
New Best Model! Loss: 0.3933
New Best Model! Loss: 0.3867
New Best Model! Loss: 0.3665
Ep 40 [STD] | Train: 0.3295 | Val: 0.3665
New Best Model! Loss: 0.3521
New Best Model! Loss: 0.3386
New Best Model! Loss: 0.3289
New Best Model! Loss: 0.3194
New Best Model! Loss: 0.2993
Ep 45 [STD] | Train: 0.3180 | Val: 0.2993
New Best Model! Loss: 0.2901
New Best Model! Loss: 0.2849
New Best Model! Loss: 0.2825
New Best Model! Loss: 0.2811
New Best Model! Loss: 0.2791
Ep 50 [STD] | Train: 0.3030 | Val: 0.2791
New Best Model! Loss: 0.2775
New Best Model! Loss: 0.2708
Ep 55 [STD] | Train: 0.2963 | Val: 0.2708
New Best Model! Loss: 0.2661
New Best Model! Loss: 0.2647
Ep 60 [STD] | Train: 0.2874 | Val: 0.2821
New Best Model! Loss: 0.2644
New Best Model! Loss: 0.2604
Ep 65 [STD] | Train: 0.2818 | Val: 0.2604
New Best Model! Loss: 0.2591
New Best Model! Loss: 0.2589
Ep 70 [STD] | Train: 0.2744 | Val: 0.2594
New Best Model! Loss: 0.2580
New Best Model! Loss: 0.2569
New Best Model! Loss: 0.2552
New Best Model! Loss: 0.2530
Ep 75 [STD] | Train: 0.2693 | Val: 0.2530
New Best Model! Loss: 0.2527
Ep 80 [STD] | Train: 0.2625 | Val: 0.2564
New Best Model! Loss: 0.2519
Ep 85 [STD] | Train: 0.2546 | Val: 0.2519
New Best Model! Loss: 0.2496
New Best Model! Loss: 0.2483
Ep 90 [STD] | Train: 0.2598 | Val: 0.2483
New Best Model! Loss: 0.2459
New Best Model! Loss: 0.2455
Ep 95 [STD] | Train: 0.2572 | Val: 0.2512
Ep 100 [STD] | Train: 0.2523 | Val: 0.2491
New Best Model! Loss: 0.2450
New Best Model! Loss: 0.2414
New Best Model! Loss: 0.2394
New Best Model! Loss: 0.2392
Ep 105 [STD] | Train: 0.2429 | Val: 0.2396
New Best Model! Loss: 0.2381
New Best Model! Loss: 0.2364
New Best Model! Loss: 0.2351
Ep 110 [STD] | Train: 0.2384 | Val: 0.2351
Ep 115 [STD] | Train: 0.2311 | Val: 0.2455
Ep 120 [STD] | Train: 0.2300 | Val: 0.2380
New Best Model! Loss: 0.2332
New Best Model! Loss: 0.2316
New Best Model! Loss: 0.2307
New Best Model! Loss: 0.2289
Ep 125 [STD] | Train: 0.2294 | Val: 0.2289
New Best Model! Loss: 0.2224
New Best Model! Loss: 0.2211
New Best Model! Loss: 0.2199
Ep 130 [STD] | Train: 0.2274 | Val: 0.2258
New Best Model! Loss: 0.2185
New Best Model! Loss: 0.2164
Ep 135 [STD] | Train: 0.2221 | Val: 0.2185
New Best Model! Loss: 0.2156
New Best Model! Loss: 0.2137
New Best Model! Loss: 0.2130
Ep 140 [STD] | Train: 0.2176 | Val: 0.2130
Ep 145 [STD] | Train: 0.2164 | Val: 0.2162
New Best Model! Loss: 0.2126
New Best Model! Loss: 0.2093
New Best Model! Loss: 0.2086
New Best Model! Loss: 0.2085
Ep 150 [STD] | Train: 0.2061 | Val: 0.2085
Ep 155 [STD] | Train: 0.2091 | Val: 0.2111
Ep 160 [STD] | Train: 0.1939 | Val: 0.2137
New Best Model! Loss: 0.2065
Ep 165 [STD] | Train: 0.1871 | Val: 0.2071
New Best Model! Loss: 0.2017
Ep 170 [STD] | Train: 0.2006 | Val: 0.2022
New Best Model! Loss: 0.2001
New Best Model! Loss: 0.1977
New Best Model! Loss: 0.1971
Ep 175 [STD] | Train: 0.1950 | Val: 0.1983
New Best Model! Loss: 0.1970
New Best Model! Loss: 0.1961
Ep 180 [STD] | Train: 0.1871 | Val: 0.1969
Ep 185 [STD] | Train: 0.1846 | Val: 0.2119
New Best Model! Loss: 0.1956
Ep 190 [STD] | Train: 0.1834 | Val: 0.1956
New Best Model! Loss: 0.1952
Ep 195 [STD] | Train: 0.1749 | Val: 0.1953
New Best Model! Loss: 0.1929
New Best Model! Loss: 0.1917
New Best Model! Loss: 0.1908
Ep 200 [STD] | Train: 0.1779 | Val: 0.1954
New Best Model! Loss: 0.1907
New Best Model! Loss: 0.1889
Ep 205 [STD] | Train: 0.1723 | Val: 0.1889
Ep 210 [STD] | Train: 0.1702 | Val: 0.1936
New Best Model! Loss: 0.1882
Ep 215 [STD] | Train: 0.1707 | Val: 0.1888
Ep 220 [STD] | Train: 0.1767 | Val: 0.1896
New Best Model! Loss: 0.1881
Ep 225 [STD] | Train: 0.1666 | Val: 0.1907
Ep 230 [STD] | Train: 0.1664 | Val: 0.1899
Ep 235 [STD] | Train: 0.1621 | Val: 0.1882
Ep 240 [STD] | Train: 0.1656 | Val: 0.1930
New Best Model! Loss: 0.1876
Ep 245 [STD] | Train: 0.1578 | Val: 0.1876
New Best Model! Loss: 0.1875
New Best Model! Loss: 0.1868
New Best Model! Loss: 0.1857
Ep 250 [STD] | Train: 0.1595 | Val: 0.1865
New Best Model! Loss: 0.1856
New Best Model! Loss: 0.1840
New Best Model! Loss: 0.1839
Ep 255 [STD] | Train: 0.1670 | Val: 0.1839
New Best Model! Loss: 0.1838
Ep 260 [STD] | Train: 0.1532 | Val: 0.1856
Ep 265 [STD] | Train: 0.1558 | Val: 0.1902
Ep 270 [STD] | Train: 0.1556 | Val: 0.1878
New Best Model! Loss: 0.1834
New Best Model! Loss: 0.1815
New Best Model! Loss: 0.1811
Ep 275 [STD] | Train: 0.1631 | Val: 0.1826
Ep 280 [STD] | Train: 0.1582 | Val: 0.1880
Ep 285 [STD] | Train: 0.1577 | Val: 0.1866
Ep 290 [STD] | Train: 0.1443 | Val: 0.1837
Ep 295 [STD] | Train: 0.1539 | Val: 0.1822
Ep 300 [STD] | Train: 0.1479 | Val: 0.1852
New Best Model! Loss: 0.1804
Ep 305 [STD] | Train: 0.1388 | Val: 0.1804
Ep 310 [STD] | Train: 0.1367 | Val: 0.1822
New Best Model! Loss: 0.1792
Ep 315 [STD] | Train: 0.1495 | Val: 0.1800
Ep 320 [STD] | Train: 0.1447 | Val: 0.1862
Ep 325 [STD] | Train: 0.1574 | Val: 0.1858
Ep 330 [STD] | Train: 0.1428 | Val: 0.1849
New Best Model! Loss: 0.1790
New Best Model! Loss: 0.1787
Ep 335 [STD] | Train: 0.1569 | Val: 0.1794
Ep 340 [STD] | Train: 0.1451 | Val: 0.1885
Ep 345 [STD] | Train: 0.1498 | Val: 0.1846
Ep 350 [STD] | Train: 0.1416 | Val: 0.1809
Ep 355 [STD] | Train: 0.1403 | Val: 0.1805
Ep 360 [STD] | Train: 0.1393 | Val: 0.1824
Ep 365 [STD] | Train: 0.1460 | Val: 0.1832
New Best Model! Loss: 0.1776
Ep 370 [STD] | Train: 0.1408 | Val: 0.1833
Ep 375 [STD] | Train: 0.1400 | Val: 0.1803
Ep 380 [STD] | Train: 0.1448 | Val: 0.1800
New Best Model! Loss: 0.1772
Ep 385 [STD] | Train: 0.1351 | Val: 0.1790
Ep 390 [STD] | Train: 0.1364 | Val: 0.1805
Ep 395 [STD] | Train: 0.1383 | Val: 0.1795
Ep 400 [STD] | Train: 0.1338 | Val: 0.1814
Ep 405 [STD] | Train: 0.1334 | Val: 0.1793
Ep 410 [STD] | Train: 0.1354 | Val: 0.1809
Ep 415 [STD] | Train: 0.1381 | Val: 0.1807
Ep 420 [STD] | Train: 0.1360 | Val: 0.1811
Ep 425 [STD] | Train: 0.1316 | Val: 0.1816
Ep 430 [STD] | Train: 0.1288 | Val: 0.1805
Ep 435 [STD] | Train: 0.1263 | Val: 0.1788
Ep 440 [STD] | Train: 0.1309 | Val: 0.1832
Ep 445 [STD] | Train: 0.1347 | Val: 0.1804
New Best Model! Loss: 0.1759
New Best Model! Loss: 0.1746
Ep 450 [STD] | Train: 0.1272 | Val: 0.1758
Ep 455 [STD] | Train: 0.1279 | Val: 0.1813
Ep 460 [STD] | Train: 0.1230 | Val: 0.1787
Ep 465 [STD] | Train: 0.1238 | Val: 0.1794
Ep 470 [STD] | Train: 0.1258 | Val: 0.1808
Ep 475 [STD] | Train: 0.1214 | Val: 0.1805
Ep 480 [STD] | Train: 0.1200 | Val: 0.1786
Ep 485 [STD] | Train: 0.1246 | Val: 0.1767
Ep 490 [STD] | Train: 0.1240 | Val: 0.1785
Ep 495 [STD] | Train: 0.1197 | Val: 0.1781
Ep 500 [STD] | Train: 0.1213 | Val: 0.1764
Ep 505 [STD] | Train: 0.1184 | Val: 0.1781
Ep 510 [STD] | Train: 0.1207 | Val: 0.1758
New Best Model! Loss: 0.1745
Ep 515 [STD] | Train: 0.1207 | Val: 0.1745
Ep 520 [STD] | Train: 0.1215 | Val: 0.1752
Ep 525 [STD] | Train: 0.1198 | Val: 0.1827
Ep 530 [STD] | Train: 0.1175 | Val: 0.1755
Ep 535 [STD] | Train: 0.1186 | Val: 0.1761
New Best Model! Loss: 0.1743
Ep 540 [STD] | Train: 0.1228 | Val: 0.1743
Ep 545 [STD] | Train: 0.1220 | Val: 0.1769
New Best Model! Loss: 0.1741
Ep 550 [STD] | Train: 0.1220 | Val: 0.1784
New Best Model! Loss: 0.1730
Ep 555 [STD] | Train: 0.1225 | Val: 0.1779
Ep 560 [STD] | Train: 0.1215 | Val: 0.1758
Ep 565 [STD] | Train: 0.1180 | Val: 0.1781
Ep 570 [STD] | Train: 0.1140 | Val: 0.1755
Ep 575 [STD] | Train: 0.1156 | Val: 0.1758
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipython-input-3277825797.py in <cell line: 0>()
    139 
    140     if epoch % 10 == 0:
--> 141         torch.save(checkpoint_dict, os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth"))
    142 
    143     if avg_val_loss < best_loss:

/usr/local/lib/python3.12/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)
    965     if _use_new_zipfile_serialization:
    966         with _open_zipfile_writer(f) as opened_zipfile:
--> 967             _save(
    968                 obj,
    969                 opened_zipfile,

/usr/local/lib/python3.12/dist-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record)
   1266                     storage = storage.cpu()
   1267             # Now that it is on the CPU we can directly copy it into the zip file
-> 1268             zip_file.write_record(name, storage, num_bytes)
   1269 
   1270 

KeyboardInterrupt: 
In [8]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
import json
import time
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
from scipy.ndimage import binary_dilation

# ==========================================
#  CONFIGURATION
# ==========================================
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================

print(f" STARTING EVALUATION FOR: {SAVE_DIR}")

# --- 1. DATASET LOADING (Auto-Fix) ---
class PunjabWheatDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.masks = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        return torch.from_numpy(self.data[idx].copy()).float(), torch.from_numpy(self.masks[idx].copy()).float()

val_x_path = os.path.join(SAVE_DIR, 'val_x.npy')
val_y_path = os.path.join(SAVE_DIR, 'val_y.npy')

if os.path.exists(val_x_path):
    print(" Dataset loaded successfully.")
    val_ds = PunjabWheatDataset(val_x_path, val_y_path)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=True)
else:
    print(" Dataset not found. Cannot run visualization.")
    val_loader = None

# --- 2. MODEL DEFINITION ---
class PatchEmbed3D(nn.Module):
    def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
    def forward(self, x): return self.proj(x).flatten(2).transpose(1, 2)

class PrithviPartialFT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbed3D()
        self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))
        enc_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=12)
        in_dim = 768 * 3
        self.decoder = nn.Sequential(
            nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 1, 1)
        )
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        if self.pos_embed.shape[1] == x.shape[1]: x = x + self.pos_embed
        x = self.encoder(x)
        x = x.transpose(1, 2).view(B, 768, 3, 14, 14).reshape(B, -1, 14, 14)
        return self.decoder(x)

# --- 3. VISUALIZATION (True RGB) ---
def visualize_rgb(model, loader, device, num_samples=3):
    print(f"\n Generating {num_samples} Visualizations (RGB)...")
    model.eval()
    try: x, y = next(iter(loader))
    except: return

    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        preds = torch.sigmoid(model(x))
        masks = (preds > 0.5).float().cpu().numpy()

    x_np, y_np = x.cpu().numpy(), y.cpu().numpy()
    actual = min(num_samples, x_np.shape[0])

    fig, axs = plt.subplots(actual, 3, figsize=(15, 5 * actual))
    plt.suptitle("Input (RGB) | Ground Truth | Prediction", fontsize=16)

    for i in range(actual):
        # RGB Composite: Red(2), Green(1), Blue(0)
        rgb = np.stack([x_np[i,2,1], x_np[i,1,1], x_np[i,0,1]], axis=2)
        p2, p98 = np.percentile(rgb, (2, 98))
        rgb = np.clip((rgb - p2) / (p98 - p2 + 1e-6), 0, 1)

        ax = axs[i] if actual > 1 else axs
        ax[0].imshow(rgb); ax[0].set_title(f"Input Sample {i+1}")
        ax[1].imshow(y_np[i,0], cmap='gray'); ax[1].set_title("Ground Truth")
        ax[2].imshow(masks[i,0], cmap='gray'); ax[2].set_title("Prediction")
        for a in ax: a.axis('off')
    plt.tight_layout(); plt.show()

# --- 4. DEEP METRICS ---
def evaluate_metrics(loader, model, device):
    print("\n Calculating Deep Metrics...")
    model.eval()
    tp, fp, fn, tn = 0, 0, 0, 0
    boundary_ious, hausdorff_dists, times = [], [], []

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            if i > 50: break # Safety limit for speed
            x = x.to(device)
            y_true = y.cpu().numpy()

            start = time.time()
            preds = model(x)
            if torch.cuda.is_available(): torch.cuda.synchronize()
            times.append((time.time() - start) / x.shape[0])

            y_pred = (torch.sigmoid(preds) > 0.5).float().cpu().numpy()

            # Pixel Metrics
            t = (y_true == 1); p = (y_pred == 1)
            tp += (t & p).sum(); fp += (~t & p).sum(); fn += (t & ~p).sum(); tn += (~t & ~p).sum()

            # Shape Metrics
            for b in range(x.shape[0]):
                pm, gm = y_pred[b,0].astype(bool), y_true[b,0].astype(bool)
                pe = binary_dilation(pm)^pm; ge = binary_dilation(gm)^gm
                i_ = (pe & ge).sum(); u_ = (pe | ge).sum()
                boundary_ious.append(i_/u_ if u_ > 0 else 1.0)
                if pm.sum()>0 and gm.sum()>0:
                    d1 = directed_hausdorff(np.argwhere(pm), np.argwhere(gm))[0]
                    d2 = directed_hausdorff(np.argwhere(gm), np.argwhere(pm))[0]
                    hausdorff_dists.append(max(d1, d2))
            if i % 10 == 0: print(".", end="")

    eps = 1e-6
    res = {
        "Pixel_Accuracy": (tp+tn)/(tp+tn+fp+fn+eps),
        "IoU (Jaccard)": tp/(tp+fp+fn+eps),
        "F1-Score (Dice)": 2*tp/(2*tp+fp+fn+eps),
        "Precision": tp/(tp+fp+eps),
        "Recall": tp/(tp+fn+eps),
        "Boundary_IoU": np.mean(boundary_ious),
        "Hausdorff_Px": np.mean(hausdorff_dists) if hausdorff_dists else 0.0,
        "FPS": 1.0 / np.mean(times)
    }

    print("\n" + "="*40)
    print(" FINAL METRICS REPORT (PARTIAL FT)")
    print("="*40)
    print(json.dumps({k: round(v, 4) for k, v in res.items()}, indent=4))
    print("="*40)

# --- EXECUTION ---
model = PrithviPartialFT().to(DEVICE)

# 1. Try to Load Best Model (Weights Only)
# Priority: Best Model -> SWA -> Latest Checkpoint
model_path = os.path.join(SAVE_DIR, "best_model.pth")
if not os.path.exists(model_path):
    print(" Best model not found, checking SWA...")
    model_path = os.path.join(SAVE_DIR, "swa_model_final.pth")
if not os.path.exists(model_path):
    print(" SWA not found, checking latest checkpoint...")
    model_path = os.path.join(SAVE_DIR, "checkpoint.pth")

if os.path.exists(model_path):
    print(f" Loading Weights from: {os.path.basename(model_path)}...", end=" ")
    try:
        ckpt = torch.load(model_path, map_location=DEVICE)
        sd = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
        new_sd = {k.replace('blocks', 'encoder.layers'): v for k,v in sd.items() if 'pos_embed' not in k or v.shape == model.pos_embed.shape}
        model.load_state_dict(new_sd, strict=False)
        print(" Loaded.")

        # 2. Run Visualization & Metrics
        if val_loader:
            visualize_rgb(model, val_loader, DEVICE, num_samples=3)
            evaluate_metrics(val_loader, model, DEVICE)
    except Exception as e:
        print(f" Failed to load weights: {e}")
else:
    print(" No valid model file found. Please check your Drive folder.")
 STARTING EVALUATION FOR: /content/drive/MyDrive/Prithvi_PartialFT_Results/
 Dataset loaded successfully.
 Loading Weights from: best_model.pth...  Loaded.

 Generating 3 Visualizations (RGB)...
No description has been provided for this image
 Calculating Deep Metrics...
.
========================================
 FINAL METRICS REPORT (PARTIAL FT)
========================================
{
    "Pixel_Accuracy": 0.8864,
    "IoU (Jaccard)": 0.8541,
    "F1-Score (Dice)": 0.9213,
    "Precision": 0.9047,
    "Recall": 0.9386,
    "Boundary_IoU": 0.0719,
    "Hausdorff_Px": 12.877,
    "FPS": 24.5055
}
========================================
In [11]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader

# ==========================================
#  CONFIGURATION
# ==========================================
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================

print(f" GENERATING 5 ROBUST SAMPLES FOR: {SAVE_DIR}")

# --- 1. DATASET LOADING ---
class PunjabWheatDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.masks = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        return torch.from_numpy(self.data[idx].copy()).float(), torch.from_numpy(self.masks[idx].copy()).float()

val_x_path = os.path.join(SAVE_DIR, 'val_x.npy')
val_y_path = os.path.join(SAVE_DIR, 'val_y.npy')

if os.path.exists(val_x_path):
    print(" Dataset loaded successfully.")
    val_ds = PunjabWheatDataset(val_x_path, val_y_path)
    val_loader = DataLoader(val_ds, batch_size=5, shuffle=True) # Batch size 5 to get enough samples at once
else:
    print(" Dataset not found. Cannot run visualization.")
    val_loader = None

# --- 2. MODEL DEFINITION ---
class PatchEmbed3D(nn.Module):
    def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
    def forward(self, x): return self.proj(x).flatten(2).transpose(1, 2)

class PrithviModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbed3D()
        self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))
        enc_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=12)
        in_dim = 768 * 3
        self.decoder = nn.Sequential(
            nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 1, 1)
        )
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        if self.pos_embed.shape[1] == x.shape[1]: x = x + self.pos_embed
        x = self.encoder(x)
        x = x.transpose(1, 2).view(B, 768, 3, 14, 14).reshape(B, -1, 14, 14)
        return self.decoder(x)

# --- 3. ROBUST VISUALIZATION (Min-Max Fix) ---
def visualize_robust(model, loader, device, num_samples=5):
    print(f"\n Visualizing {num_samples} Samples (Robust Mode)...")
    model.eval()
    try: x, y = next(iter(loader))
    except: return

    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        preds = torch.sigmoid(model(x))
        masks = (preds > 0.5).float().cpu().numpy()

    x_np, y_np = x.cpu().numpy(), y.cpu().numpy()
    actual = min(num_samples, x_np.shape[0])

    # Scale figure height dynamically: 4 inches per sample
    fig, axs = plt.subplots(actual, 3, figsize=(15, 4 * actual))
    plt.suptitle(f"Prithvi Model: {actual} Test Samples", fontsize=16, y=1.02)

    for i in range(actual):
        # 1. Extract RGB (Red=2, Green=1, Blue=0)
        rgb = np.stack([x_np[i,2,1], x_np[i,1,1], x_np[i,0,1]], axis=2)

        # 2. Force Visibility (Min-Max Normalization)
        # Allows negative/normalized values to be seen as colors
        rgb_min, rgb_max = rgb.min(), rgb.max()
        if rgb_max - rgb_min > 0:
            rgb = (rgb - rgb_min) / (rgb_max - rgb_min)
        else:
            rgb = np.zeros_like(rgb)

        ax = axs[i] if actual > 1 else axs

        # Plot RGB
        ax[0].imshow(rgb)
        ax[0].set_title(f"Sample {i+1}: Input (RGB)")

        # Plot Ground Truth
        ax[1].imshow(y_np[i,0], cmap='gray')
        ax[1].set_title("Ground Truth")

        # Plot Prediction
        ax[2].imshow(masks[i,0], cmap='gray')
        ax[2].set_title("Prediction")

        for a in ax: a.axis('off')

    plt.tight_layout()
    plt.show()

# --- EXECUTION ---
model = PrithviModel().to(DEVICE)

# Priority: Best -> SWA -> Checkpoint
model_path = os.path.join(SAVE_DIR, "best_model.pth")
if not os.path.exists(model_path): model_path = os.path.join(SAVE_DIR, "swa_model_final.pth")
if not os.path.exists(model_path): model_path = os.path.join(SAVE_DIR, "checkpoint.pth")

if os.path.exists(model_path):
    print(f" Loading Weights from: {os.path.basename(model_path)}...", end=" ")
    try:
        ckpt = torch.load(model_path, map_location=DEVICE)
        sd = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
        new_sd = {k.replace('blocks', 'encoder.layers'): v for k,v in sd.items() if 'pos_embed' not in k or v.shape == model.pos_embed.shape}
        model.load_state_dict(new_sd, strict=False)
        print(" Loaded.")

        if val_loader:
            # CALL WITH 5 SAMPLES
            visualize_robust(model, val_loader, DEVICE, num_samples=5)
    except Exception as e:
        print(f" Failed to load weights: {e}")
else:
    print(" No valid model file found.")
 GENERATING 5 ROBUST SAMPLES FOR: /content/drive/MyDrive/Prithvi_PartialFT_Results/
 Dataset loaded successfully.
 Loading Weights from: best_model.pth...  Loaded.

 Visualizing 5 Samples (Robust Mode)...
No description has been provided for this image