Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import ru.yandex.practicum.filmorate.model.Activity;
import ru.yandex.practicum.filmorate.model.Film;
import ru.yandex.practicum.filmorate.model.User;
import ru.yandex.practicum.filmorate.service.RecommendationService;
import ru.yandex.practicum.filmorate.service.UserServiceImpl;

import java.util.Collection;
Expand All @@ -20,6 +22,7 @@
@Validated
public class UserController {
private final UserServiceImpl userService;
private final RecommendationService recommendationService;

@GetMapping("/{userId}/feed")
@ResponseStatus(HttpStatus.OK)
Expand Down Expand Up @@ -80,4 +83,10 @@ public Collection<User> getAllFriends(@PathVariable("userId") long userId) {
public Collection<User> getCommonFriends(@PathVariable("userId") long userId, @PathVariable("otherId") long otherId) {
return userService.getCommonFriends(userId, otherId);
}

@GetMapping("/{userId}/recommendations")
@ResponseStatus(HttpStatus.OK)
public List<Film> getUserRecommendations(@PathVariable("userId") long userId) {
return recommendationService.getRecommendations(userId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import ru.yandex.practicum.filmorate.model.User;

import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.stream.Stream;

@Repository
Expand All @@ -26,15 +25,21 @@ public class JdbcUserRepository implements UserRepository {

private static final String CREATE_USER_QUERY = "INSERT INTO users (login, email, name, birthday) VALUES(:login,:email,:name,:birthday)";
private static final String UPDATE_USER_QUERY = "UPDATE users SET login=:login, email=:email, name=:name, birthday=:birthday WHERE user_id=:user_id";
private static final String DELETE_USER_QUERY = "DELETE FROM users WHERE user_id = :user_id";
private static final String GET_USER_BY_ID_QUERY = "SELECT * FROM users WHERE user_id = :user_id";
private static final String DELETE_USER_QUERY = "DELETE FROM users WHERE user_id=:user_id";
private static final String GET_USER_BY_ID_QUERY = "SELECT * FROM users WHERE user_id=:user_id";
private static final String GET_ALL_USERS_QUERY = "SELECT * FROM users";
private static final String ADD_FRIEND_QUERY = "INSERT INTO friends(user_id, friend_id) VALUES(:user_id,:friend_id)";
private static final String DELETE_FRIEND_QUERY = "DELETE FROM friends WHERE user_id=:user_id AND friend_id=:friend_id";
private static final String GET_USER_FRIENDS_QUERY = "SELECT * FROM users WHERE user_id IN (SELECT friend_id FROM friends WHERE user_id = :user_id)";
private static final String GET_USER_FRIENDS_QUERY = "SELECT * FROM users WHERE user_id IN (SELECT friend_id FROM friends WHERE user_id=:user_id)";
private static final String FIND_COMMON_FRIENDS = "SELECT u.* FROM users u JOIN friends f1 ON u.user_id = f1.friend_id "
+ "JOIN friends f2 ON u.user_id = f2.friend_id WHERE f1.user_id = :user_id AND f2.user_id = :other_id";
private static final String FIND_USERS_BY_IDS_QUERY = "SELECT * FROM users WHERE user_id IN (:ids)";
private static final String GET_USER_LIKED_FILMS_QUERY = """
SELECT f.film_id
FROM likes l
WHERE l.user_id =:user_id
""";
private static final String GET_ALL_USERS_LIKES = "SELECT user_id, film_id FROM likes";

private static final String ACTIVITY_GENERAL =
"INSERT INTO activity (userId, entityId, eventType, operation, timestamp) VALUES(:userId, :entityId, '";
Expand Down Expand Up @@ -154,4 +159,24 @@ public List<User> getUsersByIds(List<Long> ids) {
params.addValue("ids", ids);
return jdbc.query(FIND_USERS_BY_IDS_QUERY, params, mapper);
}

@Override
public Set<Long> getUserLikedFilms(long userId) {
MapSqlParameterSource params = new MapSqlParameterSource("user_id", userId);
List<Long> result = jdbc.queryForList(GET_USER_LIKED_FILMS_QUERY, params, Long.class);
return new HashSet<>(result);
}

@Override
public Map<Long, Set<Long>> getAllUserLikes() {
List<Map<String, Object>> rows = jdbc.queryForList(GET_ALL_USERS_LIKES, new MapSqlParameterSource());

Map<Long, Set<Long>> userLikes = new HashMap<>();
for (Map<String, Object> row : rows) {
Long userId = ((Number) row.get("user_id")).longValue();
Long filmId = ((Number) row.get("film_id")).longValue();
userLikes.computeIfAbsent(userId, k -> new HashSet<>()).add(filmId);
}
return userLikes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import ru.yandex.practicum.filmorate.model.User;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

public interface UserRepository {
List<Activity> getActivityById(long activityId);
Expand All @@ -28,4 +30,8 @@ public interface UserRepository {
List<User> getUserFriends(long userId);

List<User> getCommonFriends(long userId, long otherId);

Set<Long> getUserLikedFilms(long userId);

Map<Long, Set<Long>> getAllUserLikes();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package ru.yandex.practicum.filmorate.service;

import ru.yandex.practicum.filmorate.model.Film;
import java.util.List;

public interface RecommendationService {
List<Film> getRecommendations(long userId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package ru.yandex.practicum.filmorate.service;

import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import ru.yandex.practicum.filmorate.dal.JdbcUserRepository;
import ru.yandex.practicum.filmorate.dal.JdbcFilmRepository;
import ru.yandex.practicum.filmorate.model.Film;

import java.util.*;
import java.util.stream.Collectors;

@Service
@Slf4j
@RequiredArgsConstructor
public class RecommendationServiceImpl implements RecommendationService {
private final JdbcUserRepository userRepository;
private final JdbcFilmRepository filmRepository;
private final SlopeOne slopeOne;

@PostConstruct
public void init() {
slopeOne.buildDifferences(Collections.emptyMap());
}

private void rebuildModel() {
var raw = userRepository.getAllUserLikes();
var data = raw.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> e.getValue().stream().collect(Collectors.toMap(f -> f, f -> 1.0))
));
slopeOne.buildDifferences(data);
}

@Override
public List<Film> getRecommendations(long userId) {
var myLikes = userRepository.getAllUserLikes()
.getOrDefault(userId, Collections.emptySet());
if (myLikes.isEmpty()) {
return Collections.emptyList();
}

rebuildModel();

var predictions = slopeOne.predictRatings(
myLikes.stream().collect(Collectors.toMap(f -> f, f -> 1.0))
);
predictions.keySet().removeAll(myLikes);

if (predictions.isEmpty()) {
return Collections.emptyList();
}

var recommendations = predictions.entrySet().stream()
.sorted(Map.Entry.<Long, Double>comparingByValue().reversed())
.map(e -> filmRepository.getFilmById(e.getKey()))
.flatMap(Optional::stream)
.collect(Collectors.toList());

filmRepository.connectGenres(recommendations);
filmRepository.connectDirectors(recommendations);

return recommendations;
}
}
72 changes: 72 additions & 0 deletions src/main/java/ru/yandex/practicum/filmorate/service/SlopeOne.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package ru.yandex.practicum.filmorate.service;

import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;

@Component
public class SlopeOne {
private final Map<Long, Map<Long, Double>> diffMatrix = new HashMap<>();

private final Map<Long, Map<Long, Integer>> freqMatrix = new HashMap<>();

public void buildDifferences(Map<Long, Map<Long, Double>> data) {
for (Map<Long, Double> userRatings : data.values()) {
for (Map.Entry<Long, Double> e1 : userRatings.entrySet()) {
long i = e1.getKey();
double r1 = e1.getValue();

diffMatrix.computeIfAbsent(i, k -> new HashMap<>());
freqMatrix.computeIfAbsent(i, k -> new HashMap<>());

for (Map.Entry<Long, Double> e2 : userRatings.entrySet()) {
long j = e2.getKey();
double r2 = e2.getValue();


diffMatrix.get(i).merge(j, r1 - r2, Double::sum);

freqMatrix.get(i).merge(j, 1, Integer::sum);
}
}
}

for (Long i : diffMatrix.keySet()) {
for (Long j : diffMatrix.get(i).keySet()) {
double totalDiff = diffMatrix.get(i).get(j);
int count = freqMatrix.get(i).get(j);
diffMatrix.get(i).put(j, totalDiff / count);
}
}
}

public Map<Long, Double> predictRatings(Map<Long, Double> userRatings) {
Map<Long, Double> predictions = new HashMap<>();
Map<Long, Integer> counts = new HashMap<>();

for (Map.Entry<Long, Double> entry : userRatings.entrySet()) {
long j = entry.getKey();
double rj = entry.getValue();

for (Map.Entry<Long, Map<Long, Double>> row : diffMatrix.entrySet()) {
long i = row.getKey();
Map<Long, Double> diffs = row.getValue();

if (diffs.containsKey(j)) {
double diff = diffs.get(j);
int freq = freqMatrix.get(i).get(j);

predictions.merge(i, (diff + rj) * freq, Double::sum);
counts.merge(i, freq, Integer::sum);
}
}
}

for (Long i : predictions.keySet()) {
predictions.put(i, predictions.get(i) / counts.get(i));
}

return predictions;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ public void deleteFriend(long userId, long friendId) {

@Override
public List<User> getAllFriends(long userId) {
User user = getUserById(userId);
if (user == null) {
throw new NotFoundException("Пользователь с id = " + userId + " не найден");
}
getUserById(userId);
return jdbcUserRepository.getUserFriends(userId);
}

Expand Down