from flask import Flask, request, jsonify

# Flask application instance; the endpoint below registers itself on it via
# @app.route, and the __main__ guard at the bottom starts it with app.run().
app = Flask(__name__)

# --- Mock Model and Classification Logic ---
# In a real-world scenario, your model loading and prediction logic
# would replace the 'mock_load_model' and 'mock_classify_trip' functions.

def mock_load_model():
    """Stand in for loading the trained ML model.

    Prints a progress message and returns ``True`` as a placeholder for the
    model object that a real loader would hand back.
    """
    # A real implementation would do: model = joblib.load('your_model.pkl')
    print("Loading ML model...")
    placeholder_model = True  # nothing is actually loaded in the mock
    return placeholder_model

def mock_classify_trip(features):
    """
    Mock the prediction for a single trip segment and return its label.

    ``features`` is indexed as [linearity, duration_h, avg_speed_mph]; only
    linearity (index 0) and average speed (index 2) are consulted.

    Rules, applied in order:
      1. average speed below 1.0 MPH                        -> 'Stationary'
      2. average speed above 60.0 MPH and linearity > 0.85  -> 'Transit'
      3. anything else                                      -> 'Local Destination'
    """
    straightness, speed_mph = features[0], features[2]

    # Practically no movement: treat the segment as a stop.
    if speed_mph < 1.0:
        return 'Stationary'

    # Fast and nearly straight: looks like highway travel.
    if speed_mph > 60.0 and straightness > 0.85:
        return 'Transit'

    # Everything else is treated as localized destination activity.
    return 'Local Destination'

# --- API Endpoint Update ---

def _extract_features(trip_data):
    """Return the feature list of one trip entry, or None if it is unusable.

    A usable entry is a JSON object whose 'features' value (defaulting to
    [0.0, 0.0, 0.0] when the key is absent) is a list of at least three items
    with numeric linearity (index 0) and average speed (index 2) -- the two
    values the endpoint actually reads.
    """
    if not isinstance(trip_data, dict):
        return None

    features = trip_data.get('features', [0.0, 0.0, 0.0])
    if not isinstance(features, list) or len(features) < 3:
        return None

    for value in (features[0], features[2]):
        # JSON true/false decode to bool, which is an int subclass; reject it.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None

    return features


@app.route('/predict_multi_trip', methods=['POST'])
def predict_multi_trip():
    """Classify every trip in the request and return per-trip and daily scores.

    Expects a JSON body of the form
    ``{"trips": [{"features": [linearity, duration_h, avg_speed_mph]}, ...]}``.

    Returns:
        On success, a JSON object with:
          * "aggregate_transit_score": the highest per-trip transit score
            (0.0 when 'trips' is empty), rounded to 4 decimals.
          * "classified_trips": one ``{"ml_classification", "transit_score"}``
            entry per input trip, in input order.
        On malformed input, a JSON ``{"error": ...}`` object with HTTP 400.
    """
    data = request.get_json(force=True)

    # The body must be an object; a bare string or list containing 'trips'
    # would otherwise slip past a plain membership test and crash on indexing.
    if not isinstance(data, dict) or not isinstance(data.get('trips'), list):
        return jsonify({"error": "Invalid input structure. Expected JSON with 'trips' array."}), 400

    trips = data['trips']

    # --- 1. Process Individual Trips ---
    all_scores = []
    classified_trips = []

    for index, trip_data in enumerate(trips):
        features = _extract_features(trip_data)
        if features is None:
            # Report bad client data as 400 instead of letting it surface as a 500.
            message = (
                "Invalid trip at index {}. Expected an object whose 'features' "
                "is [linearity, duration_h, avg_speed_mph]."
            ).format(index)
            return jsonify({"error": message}), 400

        # Determine the classification label (e.g., 'Transit', 'Local Destination')
        classification_label = mock_classify_trip(features)

        # Mock transit likelihood: average speed scaled by 75 MPH, clamped to
        # [0, 1] so a negative speed cannot yield a negative score.
        transit_score = max(0.0, min(1.0, features[2] / 75.0))
        all_scores.append(transit_score)

        # Store the classification results for the PHP script
        classified_trips.append({
            "ml_classification": classification_label,
            "transit_score": round(transit_score, 4)
        })

    # --- 2. Calculate Daily Aggregate Score ---
    # The daily aggregate is the maximum transit score of any trip; a weighted
    # average or another aggregation could be substituted here.
    aggregate_score = max(all_scores, default=0.0)

    # --- 3. Return the Combined Response ---
    return jsonify({
        "aggregate_transit_score": round(aggregate_score, 4),
        "classified_trips": classified_trips
    })

# --- Service Initialization ---
if __name__ == '__main__':
    # Runs only when this file is executed directly as a script.
    # NOTE(review): when the app is imported by Gunicorn/uWSGI this guard is
    # skipped, so mock_load_model() is not invoked here -- confirm that model
    # loading happens elsewhere for production deployments.
    mock_load_model()
    # In production, use Gunicorn or uWSGI. For testing:
    # host='0.0.0.0' makes the development server listen on all interfaces.
    app.run(host='0.0.0.0', port=5000)
