Add authorization to digest-score

This commit is contained in:
Jackson Harper
2024-07-05 14:59:16 +08:00
parent 39df17fce0
commit a7ca844220
4 changed files with 59 additions and 5 deletions

View File

@ -16,13 +16,16 @@ from urllib.parse import urlparse
from datetime import datetime
import dateutil.parser
from google.cloud import storage
from features.user_history import FEATURE_COLUMNS
import concurrent.futures
from threading import Lock, RLock
from collections import ChainMap
import copy
from features.user_history import FEATURE_COLUMNS
from auth import user_token_required, admin_token_required
class ThreadSafeUserFeatures:
def __init__(self):
self._data = {}
@ -219,12 +222,14 @@ def metrics():
@app.route('/refresh', methods=['GET'])
@admin_token_required
def refresh():
refresh_data()
return jsonify({'OK': 'yes'}), 200
@app.route('/users/<user_id>/features', methods=['GET'])
@admin_token_required
def get_user_features(user_id):
result = {}
df_user = pd.DataFrame([{
@ -243,12 +248,13 @@ def get_user_features(user_id):
@app.route('/predict', methods=['POST'])
@user_token_required
def predict():
try:
data = request.get_json()
app.logger.info(f"predict scoring request: {data}")
user_id = data.get('user_id')
user_id = request.user_id
item_features = data.get('item_features')
if user_id is None:
@ -263,14 +269,13 @@ def predict():
@app.route('/batch', methods=['POST'])
@user_token_required
def batch():
start = timer()
try:
data = request.get_json()
app.logger.info(f"batch scoring request: {data}")
items = data.get('items')
user_id = data.get('user_id')
user_id = request.user_id
if user_id == None:
return jsonify({'error': 'no user_id supplied'}), 400
if len(items) > 101:

45
ml/digest-score/auth.py Normal file
View File

@ -0,0 +1,45 @@
import os
import jwt
from flask import request, jsonify
from functools import wraps
SECRET_KEY = os.getenv('JWT_SECRET')
ADMIN_SECRET_KEY = os.getenv('JWT_ADMIN_SECRET_KEY')
def user_token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = None
if 'Authorization' in request.headers:
print("request.headers['Authorization'].split(" ")[1]", request.headers['Authorization'].split(" ")[1])
token = request.headers['Authorization'].split(" ")[1]
if not token:
return jsonify({'message': 'Token is missing!'}), 401
try:
data = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
request.user_id = data['uid']
except jwt.ExpiredSignatureError:
return jsonify({'message': 'Token has expired!'}), 401
except jwt.InvalidTokenError:
return jsonify({'message': 'Token is invalid!'}), 401
return f(*args, **kwargs)
return decorated
def admin_token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = None
if 'Authorization' in request.headers:
token = request.headers['Authorization'].split(" ")[1]
if not token:
return jsonify({'message': 'Token is missing!'}), 401
try:
data = jwt.decode(token, ADMIN_SECRET_KEY, algorithms=["HS256"])
if data['role'] != 'admin':
return jsonify({'message': 'Admin token required!'}), 403
except jwt.ExpiredSignatureError:
return jsonify({'message': 'Token has expired!'}), 401
except jwt.InvalidTokenError:
return jsonify({'message': 'Token is invalid!'}), 401
return f(*args, **kwargs)
return decorated

View File

@ -9,5 +9,6 @@ sklearn2pmml
sqlalchemy
pyarrow
requests
PyJWT
prometheus_client
xgboost==2.1.0

View File

@ -3,6 +3,7 @@ import client from 'prom-client'
import { env } from '../env'
import { registerMetric } from '../prometheus'
import { logError } from '../utils/logger'
import { createWebAuthToken } from '../routers/auth/jwt_helpers'
export interface Feature {
library_item_id?: string
@ -77,10 +78,12 @@ class ScoreClientImpl implements ScoreClient {
async getScores(data: ScoreApiRequestBody): Promise<ScoreApiResponse> {
const start = Date.now()
const authToken = createWebAuthToken(data.user_id)
try {
const response = await axios.post<ScoreApiResponse>(this.apiUrl, data, {
headers: {
Authorization: `Bearer ${authToken}`,
'Content-Type': 'application/json',
},
timeout: 5000,