In [ ]:
import os
import time
import shutil
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import cv2
import ee
import geemap
from google.colab import drive
from sklearn.model_selection import train_test_split
try:
import geedim
except ImportError:
import subprocess
subprocess.check_call(["pip", "install", "geedim"])
subprocess.check_call(["pip", "install", "geemap"])
# 1. SETUP & AUTHENTICATION (Robust Fix)
drive.mount('/content/drive', force_remount=True)
def initialize_ee():
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print("Earth Engine Initialized Successfully.")
except Exception as e:
print(f"Initialization failed. Triggering Authentication... ({e})")
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print("Earth Engine Authenticated & Initialized.")
initialize_ee()
ASSET_ID = 'projects/[REDACTED_FOR_SECURITY]/assets/Punjab_Mask_2024_NEW'
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
PATCH_SIZE = 224
S2_SCALE = 5000.0
TIME_WINDOWS = [
('2024-10-15', '2024-11-15'), # T1
('2025-01-01', '2025-01-31'), # T2
('2025-02-15', '2025-03-15') # T3
]
def generate_prithvi_npy():
print(f"Starting Data Generation. Project: [REDACTED_FOR_SECURITY]")
mask_img = ee.Image(ASSET_ID)
roi_geom = mask_img.geometry()
mask_file = 'local_mask_prithvi.tif'
if not os.path.exists(mask_file):
print("Downloading Mask...")
geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)
with rasterio.open(mask_file) as src:
b = src.bounds
cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
offset = 0.04
window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)
mask = src.read(1, window=window)
mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
target_h, target_w = mask.shape
small_roi = ee.Geometry.BBox(cx-offset, cy-offset, cx+offset, cy+offset)
print(f"ROI: {target_h}x{target_w}")
stack = []
for i, (start, end) in enumerate(TIME_WINDOWS):
fname = f'prithvi_time_{i}.tif'
attempts = 0
while not os.path.exists(fname) and attempts < 3:
try:
print(f"Downloading T{i+1}: {start} to {end}...")
img = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \
.filterBounds(small_roi).filterDate(start, end).median() \
.select(['B2','B3','B4','B8','B11','B12'])
geemap.download_ee_image(img, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
except:
attempts += 1
time.sleep(2)
if not os.path.exists(fname):
if i > 0: shutil.copy(f'prithvi_time_{i-1}.tif', fname)
else:
with rasterio.open(mask_file) as src:
p = src.profile
p.update(count=6, dtype=rasterio.float32)
with rasterio.open(fname, 'w', **p) as dst: dst.write(np.zeros((6, target_h, target_w), dtype=np.float32))
with rasterio.open(fname) as src:
arr = src.read()
arr = np.transpose(arr, (1, 2, 0))
if arr.shape[:2] != (target_h, target_w):
arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
arr = np.clip(arr / S2_SCALE, 0, 1).astype(np.float32)
stack.append(arr)
full_cube = np.stack(stack, axis=2)
x_out, y_out = [], []
stride = PATCH_SIZE
print("Tiling...")
for y in range(0, target_h, stride):
for x in range(0, target_w, stride):
img_p = full_cube[y:y+stride, x:x+stride]
mask_p = mask[y:y+stride, x:x+stride]
if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
if np.mean(mask_p) < 0.01: continue
if np.isnan(img_p).any(): continue
x_out.append(img_p)
y_out.append(mask_p)
X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
y = np.array(y_out, dtype=np.float32)[:, None, :, :]
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
print("Data Generation Complete.")
generate_prithvi_npy()
Mounted at /content/drive Initialization failed. Triggering Authentication... (Please authorize access to your Earth Engine account by running earthengine authenticate in your command line, or ee.Authenticate() in Python, and then retry.) Earth Engine Authenticated & Initialized. Starting Data Generation. Project: satmae-2026 Downloading Mask...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release. Please use the 'ee.Image.gd' accessor instead. img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW: 0%| |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:googleapiclient.http:Sleeping 0.82 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:googleapiclient.http:Sleeping 1.00 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:googleapiclient.http:Sleeping 1.81 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:googleapiclient.http:Sleeping 1.12 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:googleapiclient.http:Sleeping 1.85 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:googleapiclient.http:Sleeping 1.71 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:googleapiclient.http:Sleeping 1.30 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:googleapiclient.http:Sleeping 1.45 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/satmae-2026/thumbnails?fields=name&alt=json, after 429 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'. return STACClient().get(self.id)
ROI: 891x891 Downloading T1: 2024-10-15 to 2024-11-15...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'. return STACClient().get(self.id)
Downloading T2: 2025-01-01 to 2025-01-31...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading T3: 2025-02-15 to 2025-03-15...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Tiling... Data Generation Complete.
In [ ]:
!pip install geedim
!pip install geemap
!pip install segmentation-models-pytorch
Requirement already satisfied: geedim in /usr/local/lib/python3.12/dist-packages (2.0.0) Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.12/dist-packages (from geedim) (2.0.2) Requirement already satisfied: rasterio>=1.3.8 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.0) Requirement already satisfied: click>=8 in /usr/local/lib/python3.12/dist-packages (from geedim) (8.3.1) Requirement already satisfied: tqdm>=4.6 in /usr/local/lib/python3.12/dist-packages (from geedim) (4.67.1) Requirement already satisfied: earthengine-api>=0.1.379 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.24) Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.12/dist-packages (from geedim) (0.9.0) Requirement already satisfied: fsspec>=2025.2 in /usr/local/lib/python3.12/dist-packages (from geedim) (2025.3.0) Requirement already satisfied: aiohttp>=3.11 in /usr/local/lib/python3.12/dist-packages (from geedim) (3.13.3) Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (2.6.1) Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.4.0) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (25.4.0) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.8.0) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (6.7.0) Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (0.4.1) Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.22.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.32.4) Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2.4.0) Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2026.1.4) Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (0.7.2) Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (3.3.1) Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp>=3.11->geedim) (4.15.0) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (6.2.4) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (4.9.1) Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.12/dist-packages (from yarl<2.0,>=1.17.0->aiohttp>=3.11->geedim) (3.11) Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (3.4.4) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.27.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.6.1) Requirement already satisfied: geemap in /usr/local/lib/python3.12/dist-packages (0.35.3) Requirement already satisfied: bqplot in /usr/local/lib/python3.12/dist-packages (from geemap) (0.12.45) Requirement already satisfied: colour in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.5) Requirement already satisfied: earthengine-api>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (1.5.24) Requirement already satisfied: eerepr>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.2) Requirement already satisfied: folium>=0.17.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: geocoder in /usr/local/lib/python3.12/dist-packages (from geemap) (1.38.1) Requirement already satisfied: ipyevents in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.4) Requirement already satisfied: ipyfilechooser>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.6.0) Requirement already satisfied: ipyleaflet>=0.19.2 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: ipytree in /usr/local/lib/python3.12/dist-packages (from geemap) (0.2.2) Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from geemap) (3.10.0) Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.2) Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from geemap) (2.2.2) Requirement already satisfied: plotly in /usr/local/lib/python3.12/dist-packages (from geemap) (5.24.1) Requirement already satisfied: pyperclip in /usr/local/lib/python3.12/dist-packages (from geemap) (1.11.0) Requirement already satisfied: pyshp>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from geemap) (3.0.3) Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from geemap) (7.3.2) Requirement already satisfied: scooby in /usr/local/lib/python3.12/dist-packages (from geemap) (0.11.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.32.4) Requirement already satisfied: branca>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (0.8.2) Requirement already satisfied: jinja2>=2.9 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (3.1.6) Requirement already satisfied: xyzservices in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (2025.11.0) Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (from ipyfilechooser>=0.6.0->geemap) (7.7.1) Requirement already satisfied: jupyter-leaflet<0.21,>=0.20 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.20.0) Requirement already satisfied: traittypes<3,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.2.3) Requirement already satisfied: traitlets>=4.3.0 in /usr/local/lib/python3.12/dist-packages (from bqplot->geemap) (5.7.1) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.2) Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.3) Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (8.3.1) Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.0.0) Requirement already satisfied: ratelim in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (0.1.6) Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.17.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.3.3) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (4.61.1) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.4.9) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (25.0) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (11.3.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (3.3.1) Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.12/dist-packages (from plotly->geemap) (9.1.2) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (6.2.4) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (4.9.1) Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.17.1) Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.6.10) Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.34.0) Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.16) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2>=2.9->folium>=0.17.0->geemap) (3.0.3) Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.4.4) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.11) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2026.1.4) Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ratelim->geocoder->geemap) (4.4.2) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.27.0) Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.8.15) Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.4.9) Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.1) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.6.0) Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.5) Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (26.2.1) Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.1) Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (75.2.0) Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.19.2) Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.5) Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.52) Requirement already satisfied: pygments in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.19.2) Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.9.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.6.1) Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.7) Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.8.5) Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.4) Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.1) Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.10.4) Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.16.6) Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.18.1) Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.23.1) Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.3) Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.0) Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.14) Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.5.1) Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.4) Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.13.5) Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.3.0) Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.1) Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.3.0) Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.2.0) Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.10.4) Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.21.2) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.26.0) Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.1) Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.4.0) Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2025.9.1) Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.37.0) Requirement already satisfied: rpds-py>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.30.0) Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.14.0) Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.8.1) Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.15.0) Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.23) Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.12.1) Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.12.0) Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.3) Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.7.0) Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.9.0) Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.0.0) Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.0.3) Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.4) Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.1) Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (20.11.0) Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.0) Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.1.0) Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.0) Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.10.0) Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.1) Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Collecting segmentation-models-pytorch Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB) Requirement already satisfied: huggingface-hub>=0.24 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.36.0) Requirement already satisfied: numpy>=1.19.3 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.0.2) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (11.3.0) Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.7.0) Requirement already satisfied: timm>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (1.0.24) Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.9.0+cu126) Requirement already satisfied: torchvision>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.24.0+cu126) Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (4.67.1) Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (3.20.2) Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2025.3.0) Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (25.0) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (6.0.3) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2.32.4) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (4.15.0) Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (1.2.0) Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (75.2.0) Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.14.0) Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.6.1) Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.1.6) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.80) Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (9.10.2.21) Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.4.1) Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.3.0.4) Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (10.3.7.77) Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.7.1.2) Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.5.4.2) Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (0.7.1) Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (2.27.5) Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.3.20) Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.85) Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.11.1.6) Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.5.0) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8->segmentation-models-pytorch) (1.3.0) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8->segmentation-models-pytorch) (3.0.3) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.4.4) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.11) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2.5.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2026.1.4) Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 6.4 MB/s eta 0:00:00 Installing collected packages: segmentation-models-pytorch Successfully installed segmentation-models-pytorch-0.5.0
In [ ]:
import os
import requests
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
SAVE_PATH = "/content/drive/MyDrive/Prithvi_100M.pt"
DOWNLOAD_URL = "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_100M.pt"
if not os.path.exists(SAVE_PATH):
print(f"Downloading weights to: {SAVE_PATH}")
try:
response = requests.get(DOWNLOAD_URL, stream=True)
response.raise_for_status()
with open(SAVE_PATH, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print("Download Complete.")
except Exception as e:
print(f"Download Failed: {e}")
else:
print(f"Weights exist at: {SAVE_PATH}")
Mounted at /content/drive Downloading weights to: /content/drive/MyDrive/Prithvi_100M.pt Download Complete.
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
import segmentation_models_pytorch as smp
def apply_training_augmentation(x, y):
if np.random.rand() > 0.5:
x = torch.flip(x, [4])
y = torch.flip(y, [3])
if np.random.rand() > 0.5:
x = torch.flip(x, [3])
y = torch.flip(y, [2])
k = np.random.randint(0, 4)
x = torch.rot90(x, k, [3, 4])
y = torch.rot90(y, k, [2, 3])
if np.random.rand() > 0.5:
noise = (torch.rand(x.shape[0], 1, 1, 1, 1, device=x.device) * 0.2) + 0.9
x = x * noise
return x.contiguous(), y.contiguous()
class TestTimeAugmentation(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x):
pred_orig = self.model(x)
x_hflip = torch.flip(x, [4])
pred_hflip = torch.flip(self.model(x_hflip), [3])
x_vflip = torch.flip(x, [3])
pred_vflip = torch.flip(self.model(x_vflip), [2])
x_rot = torch.rot90(x, 1, [3, 4])
pred_rot = torch.rot90(self.model(x_rot), -1, [2, 3])
return torch.stack([pred_orig, pred_hflip, pred_vflip, pred_rot]).mean(dim=0)
def validate_with_tta(model, loader, criterion, device):
model.eval()
total_loss = 0
tta_model = TestTimeAugmentation(model)
with torch.no_grad():
for x, y in loader:
x, y = x.to(device), y.to(device)
preds = tta_model(x)
loss = criterion(preds, y)
total_loss += loss.item()
return total_loss / len(loader)
class HausdorffDTLoss(nn.Module):
def __init__(self, alpha=2.0):
super().__init__()
self.alpha = alpha
def forward(self, pred, gt):
with torch.no_grad():
gt_np = gt.cpu().numpy()
dist_map = np.zeros_like(gt_np)
for i in range(len(gt_np)):
mask = (gt_np[i, 0] > 0.5).astype(np.uint8)
if mask.sum() == 0:
dist_map[i, 0] = np.ones_like(mask) * 100.0
continue
d_in = distance(mask)
d_out = distance(1 - mask)
dist_map[i, 0] = (d_out - d_in)
dist_map = torch.tensor(dist_map, device=pred.device, dtype=torch.float32)
probs = torch.sigmoid(pred)
return torch.mean((probs - gt) ** 2 * (1 + self.alpha * torch.abs(dist_map)))
class CompoundLoss(nn.Module):
def __init__(self):
super().__init__()
self.dice = smp.losses.DiceLoss(mode='binary', from_logits=True)
self.hausdorff = HausdorffDTLoss(alpha=2.0)
def forward(self, p, t):
return 0.7 * self.dice(p, t) + 0.3 * self.hausdorff(p, t)
In [ ]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
# CONFIGURATION (For Model Init)
CONFIG = {
"EPOCHS": 5000,
"PATIENCE": 200,
"BATCH_SIZE": 8,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 0.05,
"WARMUP_EPOCHS": 20,
"SWA_START_EPOCH": None,
"RESUME": True,
"DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
"IMG_SIZE": 224,
"NUM_FRAMES": 3,
"IN_CHANS": 6,
"EMBED_DIM": 768,
"DEPTH": 12,
"NUM_HEADS": 12,
"PATCH_SIZE": 16,
"PRETRAINED_PATH": "/content/drive/MyDrive/Prithvi_100M.pt",
"SAVE_DIR": SAVE_DIR
}
class PunjabWheatDataset(Dataset):
def __init__(self, x_path, y_path):
self.data = np.load(x_path, mmap_mode='r')
self.masks = np.load(y_path, mmap_mode='r')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img = self.data[idx]
mask = self.masks[idx]
return torch.from_numpy(img.copy()).float(), torch.from_numpy(mask.copy()).float()
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size):
grid_h = np.arange(grid_size, dtype=np.float32)
pos_embed_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_h)
grid_w = np.arange(grid_size, dtype=np.float32)
pos_embed_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_w)
grid_t = np.arange(t_size, dtype=np.float32)
pos_embed_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
pos_t = pos_embed_t[:, np.newaxis, np.newaxis, :]
pos_h = pos_embed_h[np.newaxis, :, np.newaxis, :]
pos_w = pos_embed_w[np.newaxis, np.newaxis, :, :]
pos_embed = pos_t + pos_h + pos_w
pos_embed = pos_embed.reshape(-1, embed_dim)
return torch.from_numpy(pos_embed).float().unsqueeze(0)
class PatchEmbed3D(nn.Module):
def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
super().__init__()
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, x):
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class PrithviPartialFT(nn.Module):
def __init__(self):
super().__init__()
self.patch_embed = PatchEmbed3D(patch_size=CONFIG["PATCH_SIZE"], frames=CONFIG["NUM_FRAMES"], in_chans=CONFIG["IN_CHANS"], embed_dim=CONFIG["EMBED_DIM"])
self.register_buffer("pos_embed", get_3d_sincos_pos_embed(CONFIG["EMBED_DIM"], CONFIG["IMG_SIZE"] // CONFIG["PATCH_SIZE"], CONFIG["NUM_FRAMES"]))
encoder_layer = nn.TransformerEncoderLayer(d_model=CONFIG["EMBED_DIM"], nhead=CONFIG["NUM_HEADS"], dim_feedforward=CONFIG["EMBED_DIM"]*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=CONFIG["DEPTH"])
in_dim = CONFIG["EMBED_DIM"] * CONFIG["NUM_FRAMES"]
self.decoder = nn.Sequential(
nn.Conv2d(in_dim, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 1, kernel_size=1)
)
self._init_decoder()
def _init_decoder(self):
for m in self.decoder.modules():
if isinstance(m, (nn.Conv2d)):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.constant_(m.bias, 0)
print("Decoder initialized.")
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.encoder(x)
H_p = CONFIG["IMG_SIZE"] // CONFIG["PATCH_SIZE"]
x = x.transpose(1, 2).view(B, CONFIG["EMBED_DIM"], CONFIG["NUM_FRAMES"], H_p, H_p)
x = x.reshape(B, -1, H_p, H_p)
return self.decoder(x)
In [6]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
train_x = os.path.join(SAVE_DIR, 'train_x.npy')
train_y = os.path.join(SAVE_DIR, 'train_y.npy')
val_x = os.path.join(SAVE_DIR, 'val_x.npy')
val_y = os.path.join(SAVE_DIR, 'val_y.npy')
if not os.path.exists(train_x):
raise FileNotFoundError(f"Data not found in {SAVE_DIR}. Run Cell 1 first.")
print("Loading Data...")
train_loader = DataLoader(PunjabWheatDataset(train_x, train_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)
val_loader = DataLoader(PunjabWheatDataset(val_x, val_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=2)
model = PrithviPartialFT().to(CONFIG["DEVICE"])
# --- WEIGHT LOADING ---
if os.path.exists(CONFIG["PRETRAINED_PATH"]):
print(f"Loading Weights from {CONFIG['PRETRAINED_PATH']}...")
ckpt = torch.load(CONFIG["PRETRAINED_PATH"], map_location='cpu')
state_dict = ckpt['model'] if 'model' in ckpt else ckpt
new_state_dict = {}
for k, v in state_dict.items():
if 'blocks' in k:
new_key = k.replace('blocks', 'encoder.layers')
new_state_dict[new_key] = v
elif 'patch_embed' in k: new_state_dict[k] = v
elif 'pos_embed' in k:
if v.shape == model.pos_embed.shape: new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
print("Pretrained Weights Loaded.")
else:
print("WARNING: PRETRAINED WEIGHTS NOT FOUND.")
# --- PARTIAL FREEZING STRATEGY (10 Frozen / 2 Unfrozen) ---
print("\nConfiguring Layers...")
# 1. Train Embeddings
for param in model.patch_embed.parameters(): param.requires_grad = True
print("- Embeddings: LEARNABLE")
# 2. Configure Encoder (0-9 Frozen, 10-11 Trainable)
for i, layer in enumerate(model.encoder.layers):
if i < 10:
for param in layer.parameters(): param.requires_grad = False
else:
for param in layer.parameters(): param.requires_grad = True
print(f"- Encoder Layers 0-9: FROZEN")
print(f"- Encoder Layers 10-11: LEARNABLE")
# 3. Train Decoder
for param in model.decoder.parameters(): param.requires_grad = True
print("- Decoder: LEARNABLE\n")
criterion = CompoundLoss().to(CONFIG["DEVICE"])
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.AdamW(trainable_params, lr=CONFIG["LEARNING_RATE"], weight_decay=CONFIG["WEIGHT_DECAY"])
scheduler_warmup = LinearLR(optimizer, start_factor=0.01, total_iters=CONFIG["WARMUP_EPOCHS"])
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=CONFIG["EPOCHS"] - CONFIG["WARMUP_EPOCHS"], eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[CONFIG["WARMUP_EPOCHS"]])
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
swa_active = False
checkpoint_path = os.path.join(CONFIG["SAVE_DIR"], "checkpoint.pth")
start_epoch = 0
best_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': []}
if CONFIG["RESUME"] and os.path.exists(checkpoint_path):
print("Found checkpoint. Loading...")
try:
ckpt = torch.load(checkpoint_path)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_loss = ckpt.get('best_loss', float('inf'))
history = ckpt.get('history', {'train_loss': [], 'val_loss': []})
if 'swa_model' in ckpt: swa_model.load_state_dict(ckpt['swa_model'])
if 'swa_active' in ckpt: swa_active = ckpt['swa_active']
print(f"Resumed from Epoch {start_epoch}")
except:
start_epoch = 0
print("Starting Partial Fine-Tuning...")
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
model.train()
train_loss = 0
for x, y in train_loader:
x, y = x.to(CONFIG["DEVICE"]), y.to(CONFIG["DEVICE"])
x, y = apply_training_augmentation(x, y)
optimizer.zero_grad()
preds = model(x)
loss = criterion(preds, y.contiguous())
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = validate_with_tta(model, val_loader, criterion, CONFIG["DEVICE"])
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
SWA_TRIGGER = CONFIG["PATIENCE"] // 2
if patience_counter >= SWA_TRIGGER and not swa_active:
print(f"Stagnation (Patience {patience_counter}). Activating SWA.")
swa_active = True
swa_model.update_parameters(model)
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
else:
scheduler.step()
checkpoint_dict = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'swa_model': swa_model.state_dict(),
'history': history,
'swa_active': swa_active
}
torch.save(checkpoint_dict, checkpoint_path)
if epoch % 10 == 0:
torch.save(checkpoint_dict, os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth"))
if avg_val_loss < best_loss:
best_loss = avg_val_loss
torch.save(model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "best_model.pth"))
patience_counter = 0
print(f"New Best Model! Loss: {best_loss:.4f}")
else:
patience_counter += 1
if patience_counter >= CONFIG["PATIENCE"]:
print(f"Early Stopping at Epoch {epoch}")
if swa_active:
torch.save(swa_model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "swa_model_final.pth"))
break
if epoch % 5 == 0:
mode = "SWA" if swa_active else "STD"
print(f"Ep {epoch} [{mode}] | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
Loading Data...
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Decoder initialized. Loading Weights from /content/drive/MyDrive/Prithvi_100M.pt... Pretrained Weights Loaded. Configuring Layers... - Embeddings: LEARNABLE - Encoder Layers 0-9: FROZEN - Encoder Layers 10-11: LEARNABLE - Decoder: LEARNABLE Starting Partial Fine-Tuning... New Best Model! Loss: 0.9044 Ep 0 [STD] | Train: 1.2669 | Val: 0.9044 New Best Model! Loss: 0.8892 New Best Model! Loss: 0.8671 New Best Model! Loss: 0.8527 Ep 5 [STD] | Train: 0.8977 | Val: 0.8527 New Best Model! Loss: 0.8485 Ep 10 [STD] | Train: 0.7293 | Val: 0.8817 Ep 15 [STD] | Train: 0.6343 | Val: 0.8854 Ep 20 [STD] | Train: 0.5313 | Val: 0.8954 Ep 25 [STD] | Train: 0.4535 | Val: 0.8913 New Best Model! Loss: 0.7621 Ep 30 [STD] | Train: 0.4097 | Val: 0.7621 New Best Model! Loss: 0.6572 New Best Model! Loss: 0.5596 New Best Model! Loss: 0.4836 New Best Model! Loss: 0.4594 New Best Model! Loss: 0.4345 Ep 35 [STD] | Train: 0.3590 | Val: 0.4345 New Best Model! Loss: 0.4039 New Best Model! Loss: 0.3933 New Best Model! Loss: 0.3867 New Best Model! Loss: 0.3665 Ep 40 [STD] | Train: 0.3295 | Val: 0.3665 New Best Model! Loss: 0.3521 New Best Model! Loss: 0.3386 New Best Model! Loss: 0.3289 New Best Model! Loss: 0.3194 New Best Model! Loss: 0.2993 Ep 45 [STD] | Train: 0.3180 | Val: 0.2993 New Best Model! Loss: 0.2901 New Best Model! Loss: 0.2849 New Best Model! Loss: 0.2825 New Best Model! Loss: 0.2811 New Best Model! Loss: 0.2791 Ep 50 [STD] | Train: 0.3030 | Val: 0.2791 New Best Model! Loss: 0.2775 New Best Model! Loss: 0.2708 Ep 55 [STD] | Train: 0.2963 | Val: 0.2708 New Best Model! Loss: 0.2661 New Best Model! Loss: 0.2647 Ep 60 [STD] | Train: 0.2874 | Val: 0.2821 New Best Model! Loss: 0.2644 New Best Model! Loss: 0.2604 Ep 65 [STD] | Train: 0.2818 | Val: 0.2604 New Best Model! Loss: 0.2591 New Best Model! Loss: 0.2589 Ep 70 [STD] | Train: 0.2744 | Val: 0.2594 New Best Model! Loss: 0.2580 New Best Model! Loss: 0.2569 New Best Model! Loss: 0.2552 New Best Model! Loss: 0.2530 Ep 75 [STD] | Train: 0.2693 | Val: 0.2530 New Best Model! Loss: 0.2527 Ep 80 [STD] | Train: 0.2625 | Val: 0.2564 New Best Model! Loss: 0.2519 Ep 85 [STD] | Train: 0.2546 | Val: 0.2519 New Best Model! Loss: 0.2496 New Best Model! Loss: 0.2483 Ep 90 [STD] | Train: 0.2598 | Val: 0.2483 New Best Model! Loss: 0.2459 New Best Model! Loss: 0.2455 Ep 95 [STD] | Train: 0.2572 | Val: 0.2512 Ep 100 [STD] | Train: 0.2523 | Val: 0.2491 New Best Model! Loss: 0.2450 New Best Model! Loss: 0.2414 New Best Model! Loss: 0.2394 New Best Model! Loss: 0.2392 Ep 105 [STD] | Train: 0.2429 | Val: 0.2396 New Best Model! Loss: 0.2381 New Best Model! Loss: 0.2364 New Best Model! Loss: 0.2351 Ep 110 [STD] | Train: 0.2384 | Val: 0.2351 Ep 115 [STD] | Train: 0.2311 | Val: 0.2455 Ep 120 [STD] | Train: 0.2300 | Val: 0.2380 New Best Model! Loss: 0.2332 New Best Model! Loss: 0.2316 New Best Model! Loss: 0.2307 New Best Model! Loss: 0.2289 Ep 125 [STD] | Train: 0.2294 | Val: 0.2289 New Best Model! Loss: 0.2224 New Best Model! Loss: 0.2211 New Best Model! Loss: 0.2199 Ep 130 [STD] | Train: 0.2274 | Val: 0.2258 New Best Model! Loss: 0.2185 New Best Model! Loss: 0.2164 Ep 135 [STD] | Train: 0.2221 | Val: 0.2185 New Best Model! Loss: 0.2156 New Best Model! Loss: 0.2137 New Best Model! Loss: 0.2130 Ep 140 [STD] | Train: 0.2176 | Val: 0.2130 Ep 145 [STD] | Train: 0.2164 | Val: 0.2162 New Best Model! Loss: 0.2126 New Best Model! Loss: 0.2093 New Best Model! Loss: 0.2086 New Best Model! Loss: 0.2085 Ep 150 [STD] | Train: 0.2061 | Val: 0.2085 Ep 155 [STD] | Train: 0.2091 | Val: 0.2111 Ep 160 [STD] | Train: 0.1939 | Val: 0.2137 New Best Model! Loss: 0.2065 Ep 165 [STD] | Train: 0.1871 | Val: 0.2071 New Best Model! Loss: 0.2017 Ep 170 [STD] | Train: 0.2006 | Val: 0.2022 New Best Model! Loss: 0.2001 New Best Model! Loss: 0.1977 New Best Model! Loss: 0.1971 Ep 175 [STD] | Train: 0.1950 | Val: 0.1983 New Best Model! Loss: 0.1970 New Best Model! Loss: 0.1961 Ep 180 [STD] | Train: 0.1871 | Val: 0.1969 Ep 185 [STD] | Train: 0.1846 | Val: 0.2119 New Best Model! Loss: 0.1956 Ep 190 [STD] | Train: 0.1834 | Val: 0.1956 New Best Model! Loss: 0.1952 Ep 195 [STD] | Train: 0.1749 | Val: 0.1953 New Best Model! Loss: 0.1929 New Best Model! Loss: 0.1917 New Best Model! Loss: 0.1908 Ep 200 [STD] | Train: 0.1779 | Val: 0.1954 New Best Model! Loss: 0.1907 New Best Model! Loss: 0.1889 Ep 205 [STD] | Train: 0.1723 | Val: 0.1889 Ep 210 [STD] | Train: 0.1702 | Val: 0.1936 New Best Model! Loss: 0.1882 Ep 215 [STD] | Train: 0.1707 | Val: 0.1888 Ep 220 [STD] | Train: 0.1767 | Val: 0.1896 New Best Model! Loss: 0.1881 Ep 225 [STD] | Train: 0.1666 | Val: 0.1907 Ep 230 [STD] | Train: 0.1664 | Val: 0.1899 Ep 235 [STD] | Train: 0.1621 | Val: 0.1882 Ep 240 [STD] | Train: 0.1656 | Val: 0.1930 New Best Model! Loss: 0.1876 Ep 245 [STD] | Train: 0.1578 | Val: 0.1876 New Best Model! Loss: 0.1875 New Best Model! Loss: 0.1868 New Best Model! Loss: 0.1857 Ep 250 [STD] | Train: 0.1595 | Val: 0.1865 New Best Model! Loss: 0.1856 New Best Model! Loss: 0.1840 New Best Model! Loss: 0.1839 Ep 255 [STD] | Train: 0.1670 | Val: 0.1839 New Best Model! Loss: 0.1838 Ep 260 [STD] | Train: 0.1532 | Val: 0.1856 Ep 265 [STD] | Train: 0.1558 | Val: 0.1902 Ep 270 [STD] | Train: 0.1556 | Val: 0.1878 New Best Model! Loss: 0.1834 New Best Model! Loss: 0.1815 New Best Model! Loss: 0.1811 Ep 275 [STD] | Train: 0.1631 | Val: 0.1826 Ep 280 [STD] | Train: 0.1582 | Val: 0.1880 Ep 285 [STD] | Train: 0.1577 | Val: 0.1866 Ep 290 [STD] | Train: 0.1443 | Val: 0.1837 Ep 295 [STD] | Train: 0.1539 | Val: 0.1822 Ep 300 [STD] | Train: 0.1479 | Val: 0.1852 New Best Model! Loss: 0.1804 Ep 305 [STD] | Train: 0.1388 | Val: 0.1804 Ep 310 [STD] | Train: 0.1367 | Val: 0.1822 New Best Model! Loss: 0.1792 Ep 315 [STD] | Train: 0.1495 | Val: 0.1800 Ep 320 [STD] | Train: 0.1447 | Val: 0.1862 Ep 325 [STD] | Train: 0.1574 | Val: 0.1858 Ep 330 [STD] | Train: 0.1428 | Val: 0.1849 New Best Model! Loss: 0.1790 New Best Model! Loss: 0.1787 Ep 335 [STD] | Train: 0.1569 | Val: 0.1794 Ep 340 [STD] | Train: 0.1451 | Val: 0.1885 Ep 345 [STD] | Train: 0.1498 | Val: 0.1846 Ep 350 [STD] | Train: 0.1416 | Val: 0.1809 Ep 355 [STD] | Train: 0.1403 | Val: 0.1805 Ep 360 [STD] | Train: 0.1393 | Val: 0.1824 Ep 365 [STD] | Train: 0.1460 | Val: 0.1832 New Best Model! Loss: 0.1776 Ep 370 [STD] | Train: 0.1408 | Val: 0.1833 Ep 375 [STD] | Train: 0.1400 | Val: 0.1803 Ep 380 [STD] | Train: 0.1448 | Val: 0.1800 New Best Model! Loss: 0.1772 Ep 385 [STD] | Train: 0.1351 | Val: 0.1790 Ep 390 [STD] | Train: 0.1364 | Val: 0.1805 Ep 395 [STD] | Train: 0.1383 | Val: 0.1795 Ep 400 [STD] | Train: 0.1338 | Val: 0.1814 Ep 405 [STD] | Train: 0.1334 | Val: 0.1793 Ep 410 [STD] | Train: 0.1354 | Val: 0.1809 Ep 415 [STD] | Train: 0.1381 | Val: 0.1807 Ep 420 [STD] | Train: 0.1360 | Val: 0.1811 Ep 425 [STD] | Train: 0.1316 | Val: 0.1816 Ep 430 [STD] | Train: 0.1288 | Val: 0.1805 Ep 435 [STD] | Train: 0.1263 | Val: 0.1788 Ep 440 [STD] | Train: 0.1309 | Val: 0.1832 Ep 445 [STD] | Train: 0.1347 | Val: 0.1804 New Best Model! Loss: 0.1759 New Best Model! Loss: 0.1746 Ep 450 [STD] | Train: 0.1272 | Val: 0.1758 Ep 455 [STD] | Train: 0.1279 | Val: 0.1813 Ep 460 [STD] | Train: 0.1230 | Val: 0.1787 Ep 465 [STD] | Train: 0.1238 | Val: 0.1794 Ep 470 [STD] | Train: 0.1258 | Val: 0.1808 Ep 475 [STD] | Train: 0.1214 | Val: 0.1805 Ep 480 [STD] | Train: 0.1200 | Val: 0.1786 Ep 485 [STD] | Train: 0.1246 | Val: 0.1767 Ep 490 [STD] | Train: 0.1240 | Val: 0.1785 Ep 495 [STD] | Train: 0.1197 | Val: 0.1781 Ep 500 [STD] | Train: 0.1213 | Val: 0.1764 Ep 505 [STD] | Train: 0.1184 | Val: 0.1781 Ep 510 [STD] | Train: 0.1207 | Val: 0.1758 New Best Model! Loss: 0.1745 Ep 515 [STD] | Train: 0.1207 | Val: 0.1745 Ep 520 [STD] | Train: 0.1215 | Val: 0.1752 Ep 525 [STD] | Train: 0.1198 | Val: 0.1827 Ep 530 [STD] | Train: 0.1175 | Val: 0.1755 Ep 535 [STD] | Train: 0.1186 | Val: 0.1761 New Best Model! Loss: 0.1743 Ep 540 [STD] | Train: 0.1228 | Val: 0.1743 Ep 545 [STD] | Train: 0.1220 | Val: 0.1769 New Best Model! Loss: 0.1741 Ep 550 [STD] | Train: 0.1220 | Val: 0.1784 New Best Model! Loss: 0.1730 Ep 555 [STD] | Train: 0.1225 | Val: 0.1779 Ep 560 [STD] | Train: 0.1215 | Val: 0.1758 Ep 565 [STD] | Train: 0.1180 | Val: 0.1781 Ep 570 [STD] | Train: 0.1140 | Val: 0.1755 Ep 575 [STD] | Train: 0.1156 | Val: 0.1758
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) /tmp/ipython-input-3277825797.py in <cell line: 0>() 139 140 if epoch % 10 == 0: --> 141 torch.save(checkpoint_dict, os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth")) 142 143 if avg_val_loss < best_loss: /usr/local/lib/python3.12/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record) 965 if _use_new_zipfile_serialization: 966 with _open_zipfile_writer(f) as opened_zipfile: --> 967 _save( 968 obj, 969 opened_zipfile, /usr/local/lib/python3.12/dist-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record) 1266 storage = storage.cpu() 1267 # Now that it is on the CPU we can directly copy it into the zip file -> 1268 zip_file.write_record(name, storage, num_bytes) 1269 1270 KeyboardInterrupt:
In [8]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
import json
import time
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
from scipy.ndimage import binary_dilation
# ==========================================
# CONFIGURATION
# ==========================================
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================
print(f" STARTING EVALUATION FOR: {SAVE_DIR}")
# --- 1. DATASET LOADING (Auto-Fix) ---
class PunjabWheatDataset(Dataset):
def __init__(self, x_path, y_path):
self.data = np.load(x_path, mmap_mode='r')
self.masks = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.data)
def __getitem__(self, idx):
return torch.from_numpy(self.data[idx].copy()).float(), torch.from_numpy(self.masks[idx].copy()).float()
val_x_path = os.path.join(SAVE_DIR, 'val_x.npy')
val_y_path = os.path.join(SAVE_DIR, 'val_y.npy')
if os.path.exists(val_x_path):
print(" Dataset loaded successfully.")
val_ds = PunjabWheatDataset(val_x_path, val_y_path)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=True)
else:
print(" Dataset not found. Cannot run visualization.")
val_loader = None
# --- 2. MODEL DEFINITION ---
class PatchEmbed3D(nn.Module):
def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
super().__init__()
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, x): return self.proj(x).flatten(2).transpose(1, 2)
class PrithviPartialFT(nn.Module):
def __init__(self):
super().__init__()
self.patch_embed = PatchEmbed3D()
self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))
enc_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=12)
in_dim = 768 * 3
self.decoder = nn.Sequential(
nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 1, 1)
)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
if self.pos_embed.shape[1] == x.shape[1]: x = x + self.pos_embed
x = self.encoder(x)
x = x.transpose(1, 2).view(B, 768, 3, 14, 14).reshape(B, -1, 14, 14)
return self.decoder(x)
# --- 3. VISUALIZATION (True RGB) ---
def visualize_rgb(model, loader, device, num_samples=3):
print(f"\n Generating {num_samples} Visualizations (RGB)...")
model.eval()
try: x, y = next(iter(loader))
except: return
x, y = x.to(device), y.to(device)
with torch.no_grad():
preds = torch.sigmoid(model(x))
masks = (preds > 0.5).float().cpu().numpy()
x_np, y_np = x.cpu().numpy(), y.cpu().numpy()
actual = min(num_samples, x_np.shape[0])
fig, axs = plt.subplots(actual, 3, figsize=(15, 5 * actual))
plt.suptitle("Input (RGB) | Ground Truth | Prediction", fontsize=16)
for i in range(actual):
# RGB Composite: Red(2), Green(1), Blue(0)
rgb = np.stack([x_np[i,2,1], x_np[i,1,1], x_np[i,0,1]], axis=2)
p2, p98 = np.percentile(rgb, (2, 98))
rgb = np.clip((rgb - p2) / (p98 - p2 + 1e-6), 0, 1)
ax = axs[i] if actual > 1 else axs
ax[0].imshow(rgb); ax[0].set_title(f"Input Sample {i+1}")
ax[1].imshow(y_np[i,0], cmap='gray'); ax[1].set_title("Ground Truth")
ax[2].imshow(masks[i,0], cmap='gray'); ax[2].set_title("Prediction")
for a in ax: a.axis('off')
plt.tight_layout(); plt.show()
# --- 4. DEEP METRICS ---
def evaluate_metrics(loader, model, device):
print("\n Calculating Deep Metrics...")
model.eval()
tp, fp, fn, tn = 0, 0, 0, 0
boundary_ious, hausdorff_dists, times = [], [], []
with torch.no_grad():
for i, (x, y) in enumerate(loader):
if i > 50: break # Safety limit for speed
x = x.to(device)
y_true = y.cpu().numpy()
start = time.time()
preds = model(x)
if torch.cuda.is_available(): torch.cuda.synchronize()
times.append((time.time() - start) / x.shape[0])
y_pred = (torch.sigmoid(preds) > 0.5).float().cpu().numpy()
# Pixel Metrics
t = (y_true == 1); p = (y_pred == 1)
tp += (t & p).sum(); fp += (~t & p).sum(); fn += (t & ~p).sum(); tn += (~t & ~p).sum()
# Shape Metrics
for b in range(x.shape[0]):
pm, gm = y_pred[b,0].astype(bool), y_true[b,0].astype(bool)
pe = binary_dilation(pm)^pm; ge = binary_dilation(gm)^gm
i_ = (pe & ge).sum(); u_ = (pe | ge).sum()
boundary_ious.append(i_/u_ if u_ > 0 else 1.0)
if pm.sum()>0 and gm.sum()>0:
d1 = directed_hausdorff(np.argwhere(pm), np.argwhere(gm))[0]
d2 = directed_hausdorff(np.argwhere(gm), np.argwhere(pm))[0]
hausdorff_dists.append(max(d1, d2))
if i % 10 == 0: print(".", end="")
eps = 1e-6
res = {
"Pixel_Accuracy": (tp+tn)/(tp+tn+fp+fn+eps),
"IoU (Jaccard)": tp/(tp+fp+fn+eps),
"F1-Score (Dice)": 2*tp/(2*tp+fp+fn+eps),
"Precision": tp/(tp+fp+eps),
"Recall": tp/(tp+fn+eps),
"Boundary_IoU": np.mean(boundary_ious),
"Hausdorff_Px": np.mean(hausdorff_dists) if hausdorff_dists else 0.0,
"FPS": 1.0 / np.mean(times)
}
print("\n" + "="*40)
print(" FINAL METRICS REPORT (PARTIAL FT)")
print("="*40)
print(json.dumps({k: round(v, 4) for k, v in res.items()}, indent=4))
print("="*40)
# --- EXECUTION ---
model = PrithviPartialFT().to(DEVICE)
# 1. Try to Load Best Model (Weights Only)
# Priority: Best Model -> SWA -> Latest Checkpoint
model_path = os.path.join(SAVE_DIR, "best_model.pth")
if not os.path.exists(model_path):
print(" Best model not found, checking SWA...")
model_path = os.path.join(SAVE_DIR, "swa_model_final.pth")
if not os.path.exists(model_path):
print(" SWA not found, checking latest checkpoint...")
model_path = os.path.join(SAVE_DIR, "checkpoint.pth")
if os.path.exists(model_path):
print(f" Loading Weights from: {os.path.basename(model_path)}...", end=" ")
try:
ckpt = torch.load(model_path, map_location=DEVICE)
sd = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
new_sd = {k.replace('blocks', 'encoder.layers'): v for k,v in sd.items() if 'pos_embed' not in k or v.shape == model.pos_embed.shape}
model.load_state_dict(new_sd, strict=False)
print(" Loaded.")
# 2. Run Visualization & Metrics
if val_loader:
visualize_rgb(model, val_loader, DEVICE, num_samples=3)
evaluate_metrics(val_loader, model, DEVICE)
except Exception as e:
print(f" Failed to load weights: {e}")
else:
print(" No valid model file found. Please check your Drive folder.")
STARTING EVALUATION FOR: /content/drive/MyDrive/Prithvi_PartialFT_Results/ Dataset loaded successfully. Loading Weights from: best_model.pth... Loaded. Generating 3 Visualizations (RGB)...
Calculating Deep Metrics...
.
========================================
FINAL METRICS REPORT (PARTIAL FT)
========================================
{
"Pixel_Accuracy": 0.8864,
"IoU (Jaccard)": 0.8541,
"F1-Score (Dice)": 0.9213,
"Precision": 0.9047,
"Recall": 0.9386,
"Boundary_IoU": 0.0719,
"Hausdorff_Px": 12.877,
"FPS": 24.5055
}
========================================
In [11]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader
# ==========================================
# CONFIGURATION
# ==========================================
SAVE_DIR = '/content/drive/MyDrive/Prithvi_PartialFT_Results/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================
print(f" GENERATING 5 ROBUST SAMPLES FOR: {SAVE_DIR}")
# --- 1. DATASET LOADING ---
class PunjabWheatDataset(Dataset):
def __init__(self, x_path, y_path):
self.data = np.load(x_path, mmap_mode='r')
self.masks = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.data)
def __getitem__(self, idx):
return torch.from_numpy(self.data[idx].copy()).float(), torch.from_numpy(self.masks[idx].copy()).float()
val_x_path = os.path.join(SAVE_DIR, 'val_x.npy')
val_y_path = os.path.join(SAVE_DIR, 'val_y.npy')
if os.path.exists(val_x_path):
print(" Dataset loaded successfully.")
val_ds = PunjabWheatDataset(val_x_path, val_y_path)
val_loader = DataLoader(val_ds, batch_size=5, shuffle=True) # Batch size 5 to get enough samples at once
else:
print(" Dataset not found. Cannot run visualization.")
val_loader = None
# --- 2. MODEL DEFINITION ---
class PatchEmbed3D(nn.Module):
def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
super().__init__()
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, x): return self.proj(x).flatten(2).transpose(1, 2)
class PrithviModel(nn.Module):
def __init__(self):
super().__init__()
self.patch_embed = PatchEmbed3D()
self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))
enc_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=12)
in_dim = 768 * 3
self.decoder = nn.Sequential(
nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 1, 1)
)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
if self.pos_embed.shape[1] == x.shape[1]: x = x + self.pos_embed
x = self.encoder(x)
x = x.transpose(1, 2).view(B, 768, 3, 14, 14).reshape(B, -1, 14, 14)
return self.decoder(x)
# --- 3. ROBUST VISUALIZATION (Min-Max Fix) ---
def visualize_robust(model, loader, device, num_samples=5):
print(f"\n Visualizing {num_samples} Samples (Robust Mode)...")
model.eval()
try: x, y = next(iter(loader))
except: return
x, y = x.to(device), y.to(device)
with torch.no_grad():
preds = torch.sigmoid(model(x))
masks = (preds > 0.5).float().cpu().numpy()
x_np, y_np = x.cpu().numpy(), y.cpu().numpy()
actual = min(num_samples, x_np.shape[0])
# Scale figure height dynamically: 4 inches per sample
fig, axs = plt.subplots(actual, 3, figsize=(15, 4 * actual))
plt.suptitle(f"Prithvi Model: {actual} Test Samples", fontsize=16, y=1.02)
for i in range(actual):
# 1. Extract RGB (Red=2, Green=1, Blue=0)
rgb = np.stack([x_np[i,2,1], x_np[i,1,1], x_np[i,0,1]], axis=2)
# 2. Force Visibility (Min-Max Normalization)
# Allows negative/normalized values to be seen as colors
rgb_min, rgb_max = rgb.min(), rgb.max()
if rgb_max - rgb_min > 0:
rgb = (rgb - rgb_min) / (rgb_max - rgb_min)
else:
rgb = np.zeros_like(rgb)
ax = axs[i] if actual > 1 else axs
# Plot RGB
ax[0].imshow(rgb)
ax[0].set_title(f"Sample {i+1}: Input (RGB)")
# Plot Ground Truth
ax[1].imshow(y_np[i,0], cmap='gray')
ax[1].set_title("Ground Truth")
# Plot Prediction
ax[2].imshow(masks[i,0], cmap='gray')
ax[2].set_title("Prediction")
for a in ax: a.axis('off')
plt.tight_layout()
plt.show()
# --- EXECUTION ---
model = PrithviModel().to(DEVICE)
# Priority: Best -> SWA -> Checkpoint
model_path = os.path.join(SAVE_DIR, "best_model.pth")
if not os.path.exists(model_path): model_path = os.path.join(SAVE_DIR, "swa_model_final.pth")
if not os.path.exists(model_path): model_path = os.path.join(SAVE_DIR, "checkpoint.pth")
if os.path.exists(model_path):
print(f" Loading Weights from: {os.path.basename(model_path)}...", end=" ")
try:
ckpt = torch.load(model_path, map_location=DEVICE)
sd = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
new_sd = {k.replace('blocks', 'encoder.layers'): v for k,v in sd.items() if 'pos_embed' not in k or v.shape == model.pos_embed.shape}
model.load_state_dict(new_sd, strict=False)
print(" Loaded.")
if val_loader:
# CALL WITH 5 SAMPLES
visualize_robust(model, val_loader, DEVICE, num_samples=5)
except Exception as e:
print(f" Failed to load weights: {e}")
else:
print(" No valid model file found.")
GENERATING 5 ROBUST SAMPLES FOR: /content/drive/MyDrive/Prithvi_PartialFT_Results/ Dataset loaded successfully. Loading Weights from: best_model.pth... Loaded. Visualizing 5 Samples (Robust Mode)...