diff --git a/ml/digest-score/app.py b/ml/digest-score/app.py index 0f4111937..f3c4dcd11 100644 --- a/ml/digest-score/app.py +++ b/ml/digest-score/app.py @@ -16,13 +16,16 @@ from urllib.parse import urlparse from datetime import datetime import dateutil.parser from google.cloud import storage -from features.user_history import FEATURE_COLUMNS import concurrent.futures from threading import Lock, RLock from collections import ChainMap import copy +from features.user_history import FEATURE_COLUMNS +from auth import user_token_required, admin_token_required + + class ThreadSafeUserFeatures: def __init__(self): self._data = {} @@ -219,12 +222,14 @@ def metrics(): @app.route('/refresh', methods=['GET']) +@admin_token_required def refresh(): refresh_data() return jsonify({'OK': 'yes'}), 200 @app.route('/users//features', methods=['GET']) +@admin_token_required def get_user_features(user_id): result = {} df_user = pd.DataFrame([{ @@ -243,12 +248,13 @@ def get_user_features(user_id): @app.route('/predict', methods=['POST']) +@user_token_required def predict(): try: data = request.get_json() app.logger.info(f"predict scoring request: {data}") - user_id = data.get('user_id') + user_id = request.user_id item_features = data.get('item_features') if user_id is None: @@ -263,14 +269,13 @@ def predict(): @app.route('/batch', methods=['POST']) +@user_token_required def batch(): start = timer() try: data = request.get_json() - app.logger.info(f"batch scoring request: {data}") - items = data.get('items') - user_id = data.get('user_id') + user_id = request.user_id if user_id == None: return jsonify({'error': 'no user_id supplied'}), 400 if len(items) > 101: diff --git a/ml/digest-score/auth.py b/ml/digest-score/auth.py new file mode 100644 index 000000000..7dbfcf5ca --- /dev/null +++ b/ml/digest-score/auth.py @@ -0,0 +1,45 @@ +import os +import jwt +from flask import request, jsonify +from functools import wraps + +SECRET_KEY = os.getenv('JWT_SECRET') +ADMIN_SECRET_KEY = os.getenv('JWT_ADMIN_SECRET_KEY') + +def user_token_required(f): + @wraps(f) + def decorated(*args, **kwargs): + token = None + if 'Authorization' in request.headers: + print("request.headers['Authorization'].split(" ")[1]", request.headers['Authorization'].split(" ")[1]) + token = request.headers['Authorization'].split(" ")[1] + if not token: + return jsonify({'message': 'Token is missing!'}), 401 + try: + data = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + request.user_id = data['uid'] + except jwt.ExpiredSignatureError: + return jsonify({'message': 'Token has expired!'}), 401 + except jwt.InvalidTokenError: + return jsonify({'message': 'Token is invalid!'}), 401 + return f(*args, **kwargs) + return decorated + +def admin_token_required(f): + @wraps(f) + def decorated(*args, **kwargs): + token = None + if 'Authorization' in request.headers: + token = request.headers['Authorization'].split(" ")[1] + if not token: + return jsonify({'message': 'Token is missing!'}), 401 + try: + data = jwt.decode(token, ADMIN_SECRET_KEY, algorithms=["HS256"]) + if data['role'] != 'admin': + return jsonify({'message': 'Admin token required!'}), 403 + except jwt.ExpiredSignatureError: + return jsonify({'message': 'Token has expired!'}), 401 + except jwt.InvalidTokenError: + return jsonify({'message': 'Token is invalid!'}), 401 + return f(*args, **kwargs) + return decorated \ No newline at end of file diff --git a/ml/digest-score/requirements.txt b/ml/digest-score/requirements.txt index b9ca8e817..b7fbe1b94 100644 --- a/ml/digest-score/requirements.txt +++ b/ml/digest-score/requirements.txt @@ -9,5 +9,6 @@ sklearn2pmml sqlalchemy pyarrow requests +PyJWT prometheus_client xgboost==2.1.0 diff --git a/packages/api/src/services/score.ts b/packages/api/src/services/score.ts index e91384a8c..71a744e04 100644 --- a/packages/api/src/services/score.ts +++ b/packages/api/src/services/score.ts @@ -3,6 +3,7 @@ import client from 'prom-client' import { env } from '../env' import { registerMetric } from '../prometheus' import { logError } from '../utils/logger' +import { createWebAuthToken } from '../routers/auth/jwt_helpers' export interface Feature { library_item_id?: string @@ -77,10 +78,12 @@ class ScoreClientImpl implements ScoreClient { async getScores(data: ScoreApiRequestBody): Promise { const start = Date.now() + const authToken = createWebAuthToken(data.user_id) try { const response = await axios.post(this.apiUrl, data, { headers: { + Authorization: `Bearer ${authToken}`, 'Content-Type': 'application/json', }, timeout: 5000,